[BUG] FFT - unhandled instruction error #713
Comments
Unfortunately, auto parallelization of FFT and control flow is not supported. What's your use case?
Hi, thanks for the reply. I hope you don't mind, but I found a way to use a rough FFT with Alpa and posted code with instructions here, in case someone else wants to use FFT with Alpa. It's just a set of functions that compute the FFT in Jax without calling the built-in function. It seems to work fine for training a model, but when I try compiling with long sequences (16k+ length), compilation seems to run indefinitely. Is there a way to manually mark a function so that it is not sliced? Maybe pinning the FFT function to one device would let the program compile; it works just fine at a length of 8,192.

I was using it for FFT convolutions, which are much faster than applying a convolution normally on long sequences. cuDNN should detect this automatically in certain situations and compile differently, but I was never able to get compilation to finish when I let it try a long sequence with a normal convolution. Thanks
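For illustration, here is a minimal sketch of the kind of workaround I mean: a dense DFT written in plain jnp, so the compiled HLO only contains a matrix multiply rather than an fft instruction. (This is a rough O(n^2) stand-in for illustration, not the code I posted.)

import jax.numpy as jnp

def dft_matrix(n):
    # Dense DFT matrix W[j, k] = exp(-2*pi*i*j*k / n).
    k = jnp.arange(n)
    return jnp.exp(-2j * jnp.pi * jnp.outer(k, k) / n).astype(jnp.complex64)

def manual_dft(x):
    # DFT over the last axis as a plain matmul, so XLA sees a dot
    # instead of an fft instruction.
    n = x.shape[-1]
    return x.astype(jnp.complex64) @ dft_matrix(n)

An FFT convolution then just multiplies the two transforms elementwise and inverse-transforms the product, which is where the speedup over a direct convolution comes from on long sequences.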
Good to hear you made it work. The code looks interesting.
Closed due to inactivity.
Please describe the bug
If alpa.parallelize (shard parallel) encounters Jax's FFT functions, the program crashes with an unhandled-instruction error.
Please describe the expected behavior
The function should run the FFT on the GPU.
System information and environment
To Reproduce
Steps to reproduce the behavior:
Screenshots
If applicable, add screenshots to help explain your problem.
F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc:620] Unhandled instruction: %fft.171 = c64[128,128,4096]{2,1,0} fft(c64[128,128,4096]{1,2,0} %transpose.78), fft_type=FFT, fft_length={4096}, metadata={op_name="parallelize(update_fn_shard_parallel)/jit(main)/jit(update_fn)/jvp(alpaBug)/VmapModel0/...)/jit(fft)/jit(fft)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4096,)]" source_file="/user/sblouir/alpaBug/ft_bug.py" source_line=393}
starter.sh: line 30: 1591031 Aborted (core dumped) $@
Code snippet to reproduce the problem
It's really just these FFT functions anywhere in the parallelize path that cause the issue. I can write a real code snippet, if needed.
import alpa
import jax.numpy as jnp
from jax import grad

def apply_fn(params, x):
    # Any jnp.fft call inside the parallelized path triggers the crash.
    return jnp.fft.fft(x)

@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)
    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

model_state = create_train_state()  # stand-in for the real state setup
for batch in data_loader:           # stand-in for the real data loader
    model_state = train_step(model_state, batch)
Additional information
jax.lax.cond crashes in a similar way (see the sketch below). I'm not sure whether this FFT implementation uses it under the hood.
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.
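For reference, a hypothetical minimal repro of the jax.lax.cond case (a sketch only; cond_step and its branches are made up for illustration and untested):

import alpa
import jax
import jax.numpy as jnp

@alpa.parallelize
def cond_step(x):
    # A data-dependent branch inside the parallelized function; auto-sharding
    # reportedly hits the same unhandled-instruction path on the resulting
    # conditional HLO.
    return jax.lax.cond(jnp.sum(x) > 0,
                        lambda v: v * 2.0,
                        lambda v: v - 1.0,
                        x)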