Support SPMD sharding of fft ops.
#24
Tracked internally in b/263023739
1. 1-D FFT: To simplify the problem, consider the case where the number of shards and the FFT size N are both powers of 2, i.e. num_shards=2^k.
- Map-reduce style: do a size `N/num_shards` FFT on each device. The remaining combine stages then each require a shuffle across devices.
2. For dim > 1 FFT, the problem actually becomes simpler: if N % num_shards == 0, one can perform the FFTs along the row dim in parallel, do a shuffle, and then perform the FFTs along the column dim in parallel. (see: http://dsp-book.narod.ru/FFTBB/0270_PDF_C23.pdf)
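The two cases in the list above can be sketched in a single process, with plain NumPy arrays standing in for devices. This is only an illustration under assumptions, not XLA's partitioner: the names `merge`, `sharded_fft_1d`, and `sharded_fft_2d` are hypothetical, the 1-D case assumes a strided (cyclic) input layout with radix-2 combine stages, and the 2-D case assumes both dimensions are divisible by `num_shards`.

```python
import numpy as np


def merge(even, odd):
    """One radix-2 butterfly stage: combine FFTs of the even/odd samples."""
    m = even.shape[0]
    twiddle = np.exp(-2j * np.pi * np.arange(m) / (2 * m))
    return np.concatenate([even + twiddle * odd, even - twiddle * odd])


def sharded_fft_1d(x, num_shards):
    """Case 1: a single 1-D FFT split over num_shards = 2^k simulated devices."""
    n = x.shape[0]
    assert num_shards > 0 and num_shards & (num_shards - 1) == 0
    assert n % num_shards == 0
    # "Device" p holds the strided slice x[p::num_shards] and runs a local
    # FFT of size n / num_shards.
    blocks = [np.fft.fft(x[p::num_shards]) for p in range(num_shards)]
    stride = num_shards
    while stride > 1:
        half = stride // 2
        # Shuffle for this stage: block p is paired with block p + half,
        # which hold the even and odd samples of x[p::half] respectively.
        blocks = [merge(blocks[p], blocks[p + half]) for p in range(half)]
        stride = half
    return blocks[0]


def sharded_fft_2d(x, num_shards):
    """Case 2: 2-D FFT as row FFTs, a shuffle, then column FFTs."""
    rows, cols = x.shape
    assert rows % num_shards == 0 and cols % num_shards == 0
    # Each "device" owns a contiguous block of rows and transforms along rows.
    row_shards = [np.fft.fft(s, axis=1) for s in np.split(x, num_shards, axis=0)]
    # Shuffle (an all-to-all on real hardware): re-partition by columns.
    full = np.concatenate(row_shards, axis=0)
    col_shards = [np.fft.fft(s, axis=0) for s in np.split(full, num_shards, axis=1)]
    return np.concatenate(col_shards, axis=1)


rng = np.random.default_rng(0)
x1 = rng.standard_normal(64) + 1j * rng.standard_normal(64)
x2 = rng.standard_normal((8, 12))
print(np.allclose(sharded_fft_1d(x1, 4), np.fft.fft(x1)))
print(np.allclose(sharded_fft_2d(x2, 4), np.fft.fft2(x2)))
```

In the 1-D sketch the list comprehension inside the loop is where the per-stage communication would happen; in the 2-D sketch the only communication is the single re-partition between the two local FFT passes.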
Oh wait, this is parallelizing a single FFT... SPMD sharding (pmap) should be a lot simpler.
So we would have to restrict to certain choices of
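For contrast with parallelizing a single FFT, here is a minimal sketch of the simpler case, assuming the sharded dimension is a batch dimension that the FFT does not run over. The function name `batch_sharded_fft` is hypothetical, and NumPy arrays again stand in for devices. Each shard is transformed independently, so no shuffle is needed.

```python
import numpy as np


def batch_sharded_fft(x, num_shards):
    """FFT over the last axis, with the leading (batch) axis split over shards."""
    assert x.shape[0] % num_shards == 0
    shards = np.split(x, num_shards, axis=0)
    # Every shard holds complete signals, so each local FFT is already final
    # and there is no cross-device communication.
    return np.concatenate([np.fft.fft(s, axis=-1) for s in shards], axis=0)


rng = np.random.default_rng(0)
batch = rng.standard_normal((8, 16))
print(np.allclose(batch_sharded_fft(batch, 4), np.fft.fft(batch, axis=-1)))
```

This only works when the sharding leaves the FFT dimensions whole on each device, which is one way to read the restriction mentioned above.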
Wait, there seems to be previous work on SPMD for FFT: 180dbd6 (see xla/xla/service/spmd/fft_handler.cc, line 350 in 7b562aa).
And it does indeed do a sharded parallel FFT...
copybara-service bot pushed a commit that referenced this issue on Oct 24, 2023
#MIGRATION_3P_TRITON__GIT_TO_THIRD_PARTY
# Commits integrated
- 726bdb984f2bcb48adfaa341ee7b0263be227b98 [FRONTEND][BACKEND] Fix constexpr assignment ; revert #24... by Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
- 87a223d76fe32a28ca563c94215a95f505794c6d bump triton_shared (#2501) by Maksim Levental <maksim.levental@gmail.com>
- 721897fcc4f942aa97d2e9ba3787a5e213758177 upgrade llvm to `b1115f8c` (NFC) (#2403) by Mehdi Amini <mamini@nvidia.com>
- 05dc28be0e72dd496300a31b99a21a5a5118f8e9 [CI] refactor workflows (#2504) by Philippe Tillet <phil@openai.com>
- 376acb610b5888263ee61713ff0a71e1d5908d69 [BUILD] Fix macos x86 build (#2505) by Thomas Raoux <thomas.raoux@openai.com>
- 768fc1fcd98ecfc0892f8982b0bb009dd7bb11ea [FRONTEND] change hash to not require ptxas (#2476) by ian Bearman <ianb@microsoft.com>
- e36d1665ca2f816212fc80ee2633caa66a0066bf [BACKEND] Fix unsupported view op created during optimiza... by Thomas Raoux <thomas.raoux@openai.com>
- a980ec50f1ed3176e2603c25f73f0ddc031cf1d8 [BACKEND] Fixing f8e5m2 to bf16 conversion on A100 (#2508) by Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
- a4f373938c9a4ba67105c5394c168945af4c990e [RUNTIME] Filter out paths that don't exist in json group... by Horace He <chilli@meta.com>
- be1de890e1f9bdf0910521b5a536c332a1c1aa2f [BACKEND] Replace assert(0) with llvm::report_fatal_error... by Keren Zhou <kerenzhou@openai.com>
- 0d57820be9ca360cf62cc3a7dc21aecc45a1c53a update triton-shared ref (#2506) by ian Bearman <ianb@microsoft.com>
- bdf464e4a8f80ad6bd6a7b470cb3d36efd61c8a2 Make kernel_static_print test work when called twice. (#2... by Justin Lebar <justin.lebar@gmail.com>
- 30186f401ec52d9addac79a60f418792875f7d11 Fix segfault in assertion test. (#2520) by Justin Lebar <justin.lebar@gmail.com>
- dc9e3063d73d2410e1855e1ff258aa90a6158548 [HOPPER] Move to tl.make_block_ptr in flash_attention bac... by runseny <145632023+runseny@users.noreply.github.com>
- b0c166b9e3f2f58c0906fa41f261787ebf3fef0d [BACKEND] Fixing bug in elementwise conversion (#2517) by Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
- 4f4c07e7d586aae3daa802ce86a9aa935f8cda17 [CI] add text file containing LLVM commit hash by Ashay Rane <ashay@users.noreply.github.com>
- 7af27fadee0fce2218a1353feea2f76ea25ad005 update hash to 76ce4736721a by Phil Tillet <phil@openai.com>
- f192611ff3bdacb8d1d1cad084dfe4cd277a0ec9 Bump LLVM version to https://github.com/llvm/llvm-project... by Goran Flegar <gflegar@google.com>
PiperOrigin-RevId: 576212898
penpornk added the enhancement (New feature or request), NVIDIA-GPU (XLA on Nvidia GPU), and GPU (XLA on GPU) labels on Feb 29, 2024
In jax-ml/jax#13081 we found that XLA doesn't support SPMD sharding of fast Fourier transform (FFT) ops. It should!