Extension of `all_gather` #5392
Labels
needs info
More information is required to diagnose & prioritize the issue.
P3 (no schedule)
We have no plan to work on this and, if it is unassigned, we would be happy to review a PR
Hi,
I want to propose a feature extension for `all_gather`. Our team is currently running test calculations on TPU pods, and we use a custom broadcasting operation built in JAX that might also be useful for others. We'd like to propose adding an optional `source_indices` argument to `all_gather` that would allow gathering the value from a particular source index to all indices within the same group (one element of `axis_index_groups`). This is a minimal implementation that achieves this:
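As an illustration only, and not the implementation referred to above, the proposed semantics can be approximated with the existing `lax.all_gather`. The sketch gathers every index's value over the whole axis and then lets each index select the row of its own group's source. The helper names `group_source_table` and `broadcast_from_source` are hypothetical, and `source_indices[g]` is assumed to be the *global* axis index of the source for group `g`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def group_source_table(axis_size, axis_index_groups, source_indices):
    """Hypothetical helper: map every axis index to its group's source index.

    Assumes source_indices[g] is a global axis index that belongs to
    axis_index_groups[g], and that the groups cover every axis index.
    """
    table = [None] * axis_size
    for group, src in zip(axis_index_groups, source_indices):
        if src not in group:
            raise ValueError(f"source {src} is not in group {group}")
        for member in group:
            table[member] = src
    if any(t is None for t in table):
        raise ValueError("axis_index_groups must cover every axis index")
    return jnp.asarray(table)


def broadcast_from_source(x, axis_name, axis_size, axis_index_groups, source_indices):
    """Hypothetical sketch: each index receives x from its group's source."""
    # Static lookup table built at trace time from Python lists.
    table = group_source_table(axis_size, axis_index_groups, source_indices)
    # Gather the values of all indices: shape (axis_size, *x.shape).
    gathered = lax.all_gather(x, axis_name)
    # Pick the row that belongs to this index's group source.
    my_source = table[lax.axis_index(axis_name)]
    return gathered[my_source]


# Usage: two groups, {0, 1} sourced from index 1 and {2, 3} from index 2.
x = jnp.array([10, 11, 12, 13])
f = lambda v: broadcast_from_source(v, "i", 4, [[0, 1], [2, 3]], [1, 2])
out = jax.vmap(f, axis_name="i")(x)
print(out)  # [11 11 12 12]
```

`jax.vmap` is used here only so the example runs on a single device; the same function can be called inside `jax.pmap` over TPU cores, with `axis_size` set to the mapped axis size. Gathering the full axis moves more data than a grouped implementation would need; a masked `lax.psum` (zeroing every non-source value and summing with `axis_index_groups`) is one lower-memory alternative under `pmap`.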