Pairwise summation takes very long time #1918

Closed
oseledets opened this issue Dec 26, 2019 · 6 comments
Labels: question (Questions for the JAX team)

@oseledets

oseledets commented Dec 26, 2019

I am trying to compute a very simple sum on a TPU in Google Colaboratory, but the jitted code runs for a very long time (>10 minutes) without finishing.

import jax
import jax.numpy as np  # older JAX convention of importing jax.numpy as np
from jax import jit

@jit
def potential(r):
    U = 4. * (np.power(r, -12) - np.power(r, -6))
    return U

@jit
def total_energy(x):
    E = 0.
    # need to speed up this part
    for i in range(N):
        for j in range(i):
            E += potential(np.sqrt(np.sum((x[i] - x[j])**2)))
    return E

and test it:

N = 1000
key = jax.random.PRNGKey(0)
points = jax.random.normal(key, (N, 3)) + 0.1

%timeit total_energy(points)

What could be the reason? The notebook can be found here: https://colab.research.google.com/drive/12xuBBs7-P3JiHGclffBPhH5Xf2EdImKn

@jotsif

jotsif commented Dec 26, 2019

Hi!

I think you should use lax loop primitives instead of Python loops, since Python loops get unrolled during tracing. See #402 for an example.

@mattjj
Collaborator

mattjj commented Dec 27, 2019

Thanks, @jotsif. I think you're right about the underlying issue.

To say a bit more, this program is probably taking a long time to compile because those Python loops are getting unrolled, yielding a large XLA program. (I'd expect that, after it's compiled once, it'll be very fast to execute!)

One option is to write these loops using loop primitives that get translated directly to XLA loop constructs (rather than being unrolled), keeping the XLA program small and hence compile times short. You can use lax.fori_loop or lax.scan for that (the latter being reverse-mode differentiable). (There's also an experimental way to embed such loops more conveniently.) Nested loops get a little awkward though:

from jax.lax import fori_loop

@jit
def total_energy(x):
  N = x.shape[0]
  return fori_loop(0, N, lambda i, E:
                   fori_loop(0, i, lambda j, E:
                             E + potential(np.sqrt(np.sum((x[i]-x[j])**2))),
                             E),
                   0.)
%timeit -n 3 -r 3 total_energy(points).block_until_ready()
3 loops, best of 3: 560 ms per loop

(See the docs on asynchronous dispatch for an explanation of why you need block_until_ready() when performing microbenchmarks.)

That's pretty slow, much slower than CPU execution. I'm not sure what's going on. I'll raise it with our TPU experts next week.

However, if this example is similar to the computation you actually care about, you can write it in a vectorized way in terms of NumPy primitives rather than using any explicit loops:

@jit
def potential(r):
  U = 4. * (np.power(r, -12) - np.power(r, -6))
  return U

@jit
def total_energy(x):
  distances = np.sqrt(np.sum((x[:, None, :] - x[None, :, :])**2, axis=-1))
  return np.sum(np.tril(potential(distances), k=-1))

%timeit -n 10 -r 10 total_energy(points).block_until_ready()
10 loops, best of 10: 958 µs per loop

Is that style an option for you?

@mattjj mattjj added the question Questions for the JAX team label Dec 27, 2019
@mattjj mattjj self-assigned this Dec 27, 2019
@mattjj
Collaborator

mattjj commented Dec 27, 2019

You can also compute Euclidean pairwise distances using a polarization identity, but in my brief tests that seemed less numerically stable here and not faster.
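
For reference, a minimal sketch of that identity, ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b> (an illustration, not code from this thread; the clamp before the sqrt is an assumption to guard against small negatives from floating-point cancellation, which is the likely source of the instability mentioned above):

import jax.numpy as jnp

def pairwise_dists_polarization(x):
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
    sq_norms = jnp.sum(x**2, axis=-1)  # shape (N,)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (x @ x.T)
    # Cancellation can leave tiny negative values; clamp before the sqrt.
    return jnp.sqrt(jnp.maximum(sq_dists, 0.0))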

@mattjj
Collaborator

mattjj commented Jan 7, 2020

I'm going to close this issue because it's not active, and because I think we covered the relevant stuff. Please re-open (or open a new issue) if that's not the case!

@mattjj mattjj closed this as completed Jan 7, 2020
@tlh24

tlh24 commented May 7, 2022

I'm also trying to efficiently calculate pairwise distances with large matrices. The problem now, however, is that JAX always seems to run out of memory.

import jax
import jax.numpy as jnp

def distance(x, siz):
    @jax.jit
    def l2_dist2(m, ai, bi):
        return jax.lax.fori_loop(0, siz,
                                 lambda i, c: c + (m[ai, i] - m[bi, i])**2, 0.0)

    def l2_dist_cond(m, ai, bi):
        return jax.lax.cond(bi > ai, l2_dist2, lambda a, b, c: 0.0, m, ai, bi)

    def l2_dist3(m, ind, vi):
        return jax.vmap(l2_dist_cond, (None, 0, None), 0)(m, ind, vi)

    ind1 = jnp.arange(0, siz)
    ind2 = jnp.arange(0, siz)
    dist = jax.vmap(l2_dist3, (None, None, 0), 0)(x, ind1, ind2)

    return dist + jnp.transpose(dist)

seed = 17016
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, (16384, 7000))
d = distance(x, 16384)

No matter how I write it, JAX allocates a 16k x 16k x 7k tensor, when (it seems to me) it should only need a 16k x 16k matrix. Hmm? (This is almost certainly not the best way to write it!)

In the meantime, I shuttle data from JAX to PyTorch and use the cdist function there, which works quickly, minus the translation time.
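
(One likely contributor to the blow-up above: under vmap, lax.cond is lowered to a select, so both branches are evaluated for every pair and the bi > ai short-circuit saves nothing. A sketch of one way to keep the peak near the 16k x 16k output, processing one row at a time with lax.map instead of the nested vmaps; pairwise_sq_dists is a hypothetical name, not code from the thread:)

import jax
import jax.numpy as jnp

@jax.jit
def pairwise_sq_dists(x):
    # Build the (N, N) result one row at a time; lax.map runs sequentially,
    # so the largest temporary is a single (N, d) broadcast, never (N, N, d).
    def one_row(xi):
        return jnp.sum((x - xi)**2, axis=-1)
    return jax.lax.map(one_row, x)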

@kheyer

kheyer commented Aug 5, 2022

Is there a preferred JAX way to run computations that produce a matrix/tensor of outputs? I've run into some of the issues described here.

For a toy example, say we want to calculate pairwise distances between two arrays of vectors, v1 of shape (n, d) and v2 of shape (m, d), by computing the squared Euclidean distance for each vector pair. (There are more efficient ways of doing this, but this method illustrates the point.)

We start with:

def compute_distances(v1, v2):
    return np.power((v1[:,None,:] - v2[None,:,:]), 2).sum(-1)

A problem with this method is that NumPy allocates a full (n, m, d) tensor during the computation. Running this on my machine with inputs of size v1=(3000, 768) and v2=(3000, 768) takes 57 seconds and has a memory peak of 52 GB!

One way around this is to pre-allocate the output array, loop over one of the inputs to minimize memory allocation, and use Numba to speed things up:

import numba
import numpy as np

@numba.jit
def compute_distances(v1, v2):
    rows = v1.shape[0]
    cols = v2.shape[0]
    output = np.zeros((rows, cols), dtype=v1.dtype)

    for row in range(rows):
        # squared Euclidean distance from v1[row] to every row of v2
        distance = np.power(v1[row] - v2, 2).sum(-1)
        output[row] = distance

    return output

This method runs the above test in 10 seconds with negligible memory usage.

Is there a similar approach for JAX? Below is the standard (as I understand it) JAX approach:

def compute_pair_distance(a, b):
    # squared Euclidean distance between two vectors
    return jnp.power(a - b, 2).sum(-1)

compute_distances = jax.vmap(jax.vmap(compute_pair_distance, (None, 0), 0), (0, None), 0)
distances = compute_distances(v1, v2)

This has a similar runtime to Numba - about 10 seconds - but has a huge memory peak of 80 GB. Is there a way to do this sort of computation in JAX where the output array is pre-allocated and filled in? I haven't been able to figure out a way.
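
(A sketch of one possible JAX analogue of that preallocate-and-fill pattern, not from the thread; it assumes the squared-distance intent above. Inside jit, XLA can usually perform the functional out.at[row].set(...) update of a loop carry in place, so peak memory stays near the (n, m) output plus one (m, d) broadcast:)

import jax
import jax.numpy as jnp

@jax.jit
def compute_distances(v1, v2):
    # Mirror the Numba pattern: allocate the (n, m) output once and fill
    # it row by row, so no (n, m, d) intermediate is ever materialized.
    n = v1.shape[0]
    out = jnp.zeros((n, v2.shape[0]), dtype=v1.dtype)

    def body(row, out):
        dist = jnp.sum((v1[row] - v2)**2, axis=-1)
        return out.at[row].set(dist)

    return jax.lax.fori_loop(0, n, body, out)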
