extension of all_gather #5392

Open
mganahl opened this issue Jan 13, 2021 · 2 comments
Assignees
Labels
needs info: More information is required to diagnose & prioritize the issue.
P3 (no schedule): We have no plan to work on this and, if it is unassigned, we would be happy to review a PR.

Comments

@mganahl

mganahl commented Jan 13, 2021

Hi,
I wanted to propose a feature extension for all_gather. Our team is currently running test calculations on TPU pods, and we use a custom broadcasting operation built in JAX that might also be useful for others.

We'd like to propose adding an optional source_indices argument to all_gather that would allow
gathering the value from one designated source index per group to all indices within that group of axis_index_groups, e.g.

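A = np.arange(4)
# source_indices is the proposed argument; jax.lax.all_gather does not accept it today.
print(pmap(lambda x: jax.lax.all_gather(x,
    axis_name='i',
    axis_index_groups=[[0,2],[1,3]],
    source_indices=[0,3]), axis_name='i')(A))  # [0,3,0,3]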

This is a minimal implementation that achieves this:

import jax
import jax.numpy as jnp


def mask(A, cond):
    # Keep A where cond is true and zero it elsewhere (cond broadcasts against A).
    return jnp.where(cond, A, jnp.zeros_like(A))


def all_gather(x, axis_name, axis_index_groups=None, source_indices=None):
    if source_indices is None:
        return jax.lax.all_gather(x, axis_name=axis_name,
                                  axis_index_groups=axis_index_groups)
    # One source index per group, and each source must belong to its group.
    assert len(source_indices) == len(axis_index_groups)
    for n, source in enumerate(source_indices):
        assert source in axis_index_groups[n]

    # Only the source device of each group contributes a non-zero term, so the
    # per-group psum leaves every device holding its group's source value.
    axis_index = jax.lax.axis_index(axis_name)
    keep = jnp.any(axis_index == jnp.asarray(source_indices))
    masked = mask(x, keep)
    return jax.lax.psum(masked, axis_name, axis_index_groups=axis_index_groups)
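
For checking the intended semantics without a multi-device setup, here is a rough reference sketch. It swaps the mask-and-psum technique for a different one: gather everything along the axis, then select the source of the caller's group. The name `group_broadcast_reference` is hypothetical, and `jax.vmap` with `axis_name` is used only as a single-device stand-in for `pmap`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def group_broadcast_reference(x, axis_name, axis_index_groups, source_indices):
    """Hypothetical reference: every index receives its group's source value."""
    # Static lookup table mapping each axis index to the source of its group.
    n = sum(len(g) for g in axis_index_groups)
    table = np.zeros(n, dtype=np.int32)
    for group, source in zip(axis_index_groups, source_indices):
        assert source in group
        table[list(group)] = source

    # Gather all values along the axis, then pick the one from our group's source.
    gathered = jax.lax.all_gather(x, axis_name)  # shape (n, *x.shape)
    my_source = jnp.asarray(table)[jax.lax.axis_index(axis_name)]
    return gathered[my_source]


A = jnp.arange(4) + 10  # offset so the value from source index 0 is visibly non-zero
out = jax.vmap(
    lambda x: group_broadcast_reference(x, 'i', [[0, 2], [1, 3]], [0, 3]),
    axis_name='i',
)(A)
print(out)  # [10 13 10 13]
```

This moves more data than the psum version, since every index receives every value, so it is probably only useful as a correctness check against the masked-psum implementation.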
    
    
    
hawkinsp removed their assignment May 14, 2021
@apaszke
Collaborator

apaszke commented Jun 2, 2021

So it's really more like a broadcast from a single device to all other devices in its axis index group? I don't think that fits very well with the semantics of all_gather, so it should at least be a separate function. But, perhaps more importantly, I think that we should start moving away from axis_index_groups, in favor of xmap. So I think that we might be interested in merging something similar to what you propose, but only if we can make sense of it in the world of multidimensional meshes and named axes. Do you think that makes sense?
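
One way to picture the named-axes framing is a broadcast along a single named axis of a 2D layout, with no axis_index_groups at all. The sketch below is only an illustration under assumptions: `broadcast_from` is a hypothetical helper, not an existing JAX function, and nested `jax.vmap` calls with `axis_name` stand in for the two mesh axes (this is not xmap).

```python
import jax
import jax.numpy as jnp


def broadcast_from(x, axis_name, source):
    """Hypothetical helper: broadcast x from index `source` along a named axis."""
    is_source = jax.lax.axis_index(axis_name) == source
    # Only the source position contributes a non-zero term to the sum.
    return jax.lax.psum(jnp.where(is_source, x, jnp.zeros_like(x)), axis_name)


grid = jnp.arange(6).reshape(2, 3)  # laid out as rows x cols

# Broadcast along 'cols' from column 1; each row acts as its own group.
f = jax.vmap(
    jax.vmap(lambda x: broadcast_from(x, 'cols', 1), axis_name='cols'),
    axis_name='rows',
)
out = f(grid)
print(out)  # [[1 1 1]
            #  [4 4 4]]
```

Here the grouping falls out of the axis structure: reducing over 'cols' alone keeps the rows independent, which plays the role that each entry of axis_index_groups plays in the original proposal.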

apaszke self-assigned this Jun 2, 2021
apaszke added the P3 (no schedule) and needs info labels Jun 2, 2021
@chaserileyroberts
Contributor

I believe that once the RFC linked below is accepted into StableHLO, it should be exactly what you want. I plan on adding this to JAX under lax.pcollective_broadcast once the RFC is accepted.

openxla/stablehlo#1809


4 participants