Support SPMD sharding of fft ops. #24

Open
hawkinsp opened this issue Nov 14, 2022 · 4 comments
Labels
enhancement (New feature or request) · GPU (XLA on GPU) · NVIDIA-GPU (XLA on Nvidia GPU)

Comments

@hawkinsp
Member

In jax-ml/jax#13081 we found that XLA doesn't support SPMD sharding of fast Fourier transform (FFT) ops. It should!

@cheshire
Contributor

Tracked internally in b/263023739

@jon-chuang

jon-chuang commented Apr 6, 2023

1. 1-D FFT: to simplify the problem, consider the case where the number of shards and the FFT size are both powers of 2, i.e. `num_shards = 2^k`. Map-reduce style: do a size-`N/num_shards` FFT on each device, then perform a shuffle per stage.
2. For FFTs of dim > 1, the problem actually becomes simpler: if `N % num_shards == 0`, one can perform the FFTs along the row dimension, do a shuffle, and then perform the FFTs in parallel along the column dimension (see: http://dsp-book.narod.ru/FFTBB/0270_PDF_C23.pdf). A sketch of this decomposition is given below.

So we would have to restrict to certain choices of $N$ and `num_shards`.

Oh wait, this is parallelizing a single FFT... SPMD sharding (pmap) should be a lot simpler.
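For concreteness, here is a minimal sketch of the dim > 1 decomposition in point 2, written with JAX's `shard_map`. This is only an illustration, not the XLA implementation; `sharded_fft2`, the mesh axis name `x`, and the divisibility assumptions are mine.

```python
# Sketch only: decompose a 2-D FFT over a 1-D device mesh.
# Assumes N % num_shards == 0 and M % num_shards == 0; names are illustrative.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def sharded_fft2(a):  # a: (N, M) complex array, sharded along rows
    def per_shard(block):  # block: (N/num_shards, M) on each device
        block = jnp.fft.fft(block, axis=1)  # local FFTs along the unsharded axis
        # Shuffle so each device holds full columns: (N, M/num_shards).
        block = jax.lax.all_to_all(block, "x", split_axis=1, concat_axis=0, tiled=True)
        block = jnp.fft.fft(block, axis=0)  # local FFTs along the other axis
        # Shuffle back to the original row sharding: (N/num_shards, M).
        return jax.lax.all_to_all(block, "x", split_axis=0, concat_axis=1, tiled=True)

    return shard_map(per_shard, mesh=mesh,
                     in_specs=P("x", None), out_specs=P("x", None),
                     check_rep=False)(a)  # skip the replication checker for simplicity

# Sanity check (both dims must be divisible by jax.device_count()):
# x = jnp.arange(64, dtype=jnp.complex64).reshape(8, 8)
# np.testing.assert_allclose(sharded_fft2(x), jnp.fft.fft2(x), rtol=1e-4, atol=1e-4)
```

The two `all_to_all`s are the "shuffle" step: the first trades the sharded row dimension for the column dimension so each device can run full-length FFTs locally, and the second restores the original sharding.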

@jon-chuang

jon-chuang commented Apr 6, 2023

Wait, there seems to be previous work on SPMD for FFT: 180dbd6.
It seems to only support rank >= 3:

if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
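// Reading of the guard above (my interpretation, not from the commit message):
// the partitioned path applies only to forward complex FFTs (FftType::FFT) on
// operands of rank >= 3; other cases presumably fall back to the default,
// unpartitioned handling.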

And it does indeed do a sharded parallel FFT...

@jon-chuang

Could you clarify what is meant by SPMD in the case of FFTs, @hawkinsp @cheshire?

Does this mean running multiple FFTs on tensors with rank >= 2, along a batch dimension? Or does it mean parallelising a single FFT across the FFT's ranks?
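To make the distinction concrete: the first reading is just independent FFTs over a sharded batch dimension, which in principle needs no cross-shard communication at all. A minimal sketch (array shapes and names are illustrative; whether XLA's partitioner actually keeps this sharded instead of gathering/replicating is what this issue is about):

```python
# First reading: independent FFTs along a sharded batch dimension.
# Sketch only; the FFT axis itself stays unsharded.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
x = jnp.zeros((8 * jax.device_count(), 1024), dtype=jnp.complex64)
x = jax.device_put(x, NamedSharding(mesh, P("batch", None)))  # shard the batch dim

# The FFT axis (-1) is unsharded, so each device only needs its own batch slice.
y = jax.jit(lambda a: jnp.fft.fft(a, axis=-1))(x)
print(y.sharding)
```

The second reading is the single-FFT decomposition sketched in my earlier comment, which is what 180dbd6 appears to do for rank >= 3.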

copybara-service bot pushed a commit that referenced this issue Oct 24, 2023
@penpornk added the enhancement, NVIDIA-GPU, and GPU labels on Feb 29, 2024