This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[BUG] FFT - unhandled instruction error #713

Closed
samblouir opened this issue Sep 23, 2022 · 4 comments

Comments

@samblouir

samblouir commented Sep 23, 2022

Please describe the bug

If alpa.parallelize (shard parallel) encounters JAX's FFT functions, the program crashes with an unhandled instruction error.

Please describe the expected behavior
The function runs the FFT on the GPU.

System information and environment

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): RHEL 8 using Singularity -> Ubuntu 20.04
  • Python version: 3.9.9
  • CUDA version: 11.7
  • NCCL version: 2.13.4-1
  • cupy version: cupy-cuda115==10.6.0
  • GPU model and memory: 4x Nvidia A100-AXM 80GB
  • Alpa version: 0.2.0
  • TensorFlow version: 2.9.1
  • JAX version: 0.3.15, jaxlib==0.3.15+cuda113.cudnn820

To Reproduce
Steps to reproduce the behavior:

  1. Create a model that uses JAX's or NumPy's FFT, RFFT, etc. (alpa.value_and_grad and jax.value_and_grad crash the same way).
  2. The program crashes with an unhandled instruction error while alpa.parallelize is being evaluated.

Screenshots

F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc:620] Unhandled instruction: %fft.171 = c64[128,128,4096]{2,1,0} fft(c64[128,128,4096]{1,2,0} %transpose.78), fft_type=FFT, fft_length={4096}, metadata={op_name="parallelize(update_fn_shard_parallel)/jit(main)/jit(update_fn)/jvp(alpaBug)/VmapModel0/...)/jit(fft)/jit(fft)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4096,)]" source_file="/user/sblouir/alpaBug/ft_bug.py" source_line=393}
starter.sh: line 30: 1591031 Aborted (core dumped) $@

Code snippet to reproduce the problem

It's really just these FFT functions anywhere in the parallelize path that cause the issue. I can write a real code snippet, if needed.

import alpa
import jax.numpy as jnp

def apply_fn(params, x):
    return jnp.fft.fft(x)

@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = alpa.grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

model_state = create_train_state()  # model/state construction elided
for batch in data_loader:           # data loading elided
    model_state = train_step(model_state, batch)

Additional information
jax.lax.cond crashes similarly. I'm not sure whether this FFT implementation uses it underneath.

@merrymercy
Member

merrymercy commented Sep 23, 2022

Unfortunately, auto parallelization of FFT and control flow is not supported.
We are working on supporting control flow (#400), but FFT is not in our current plan and needs non-trivial effort.

What's your use case?

@samblouir
Author

samblouir commented Sep 25, 2022

Hi,

Thanks for the reply. I hope you don't mind, but I found a way to get a rough FFT working with Alpa and posted the code with instructions here, in case someone else wants to use FFT with Alpa. It's just a few functions that compute the FFT in JAX without calling the built-in function.
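
As a rough illustration of the idea (a hypothetical sketch along the same lines, not necessarily the posted code), one way to avoid the built-in call is to write the DFT as an explicit matrix multiply, so the compiled program contains only dot/transpose ops instead of an fft instruction:

import jax.numpy as jnp

def dft_matrix(n, dtype=jnp.complex64):
    # n x n matrix of roots of unity: W[j, k] = exp(-2*pi*i*j*k / n)
    k = jnp.arange(n)
    return jnp.exp(-2j * jnp.pi * jnp.outer(k, k) / n).astype(dtype)

def manual_fft(x):
    # Naive O(n^2) DFT over the last axis, expressed as a matmul.
    # Matches jnp.fft.fft up to floating-point error.
    n = x.shape[-1]
    return x.astype(jnp.complex64) @ dft_matrix(n)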

It seems to be working fine for training a model, but when I try compiling with long sequences (16k+ length), it appears to compile indefinitely. Is there a way to manually mark a function so it is not sliced? Maybe forcing the FFT function to stay on one device would allow the program to compile. It works just fine with a length of 8,192.

I was using it for FFT convolutions. On long sequences these are much faster than applying a convolution directly. cuDNN should automatically detect this in certain situations and compile differently, but I was never able to get compilation to finish when I tried a long sequence with a normal convolution.
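
For reference, the FFT-convolution idea is roughly the following (a minimal sketch using jnp.fft directly, which is exactly the kind of call that triggers the crash above under alpa.parallelize; in the workaround those calls would be replaced by the hand-rolled FFT):

import jax.numpy as jnp

def fft_conv1d(signal, kernel):
    # Linear convolution via the convolution theorem: pad both inputs to
    # the full output length, multiply in the frequency domain, and
    # transform back. O(n log n) instead of O(n * m) for a direct conv.
    n = signal.shape[-1] + kernel.shape[-1] - 1
    sig_f = jnp.fft.rfft(signal, n=n)
    ker_f = jnp.fft.rfft(kernel, n=n)
    return jnp.fft.irfft(sig_f * ker_f, n=n)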

Thanks

@merrymercy
Member

Good to hear you made it work. The code looks interesting.
Do you know why it compiles indefinitely? Is it because it takes too much memory, or because the program is too long? Do you know which line the program hangs at? Can this be reproduced with your script?

@merrymercy
Member

Closed due to inactivity.
Related: openxla/xla#24
