INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to launch CUDA kernel: fusion_1560 with block dimensions: 160x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
I had some out-of-bounds indexing in my code. E.g.

```python
x = jnp.asarray([0, 1, 2])
y = x[4]
```

would not raise an error in JAX, but it broke the gradient computation in my case. See also this JAX issue, which describes that out-of-bounds indexing leads to undefined behaviour in JAX during backprop. In my case, out-of-bounds indexing occurred twice: (1) the parameters of each cell were flattened before being passed to neurax (instead of being kept as a list), and (2) I had stimuli into 10 neurons even when fewer than 10 neurons were available.
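For anyone hitting the same issue: a minimal sketch of why the read above is silent. On the forward pass JAX clamps out-of-bounds gather indices to the nearest valid index, so no error is raised; `.at[...].get(mode="fill")` is one way to make the out-of-bounds access visible instead.

```python
import jax.numpy as jnp

x = jnp.asarray([0, 1, 2])

# Out-of-bounds reads are clamped to the last valid index on the forward
# pass, so this silently returns x[2] instead of raising.
print(x[4])  # → 2

# mode="fill" returns a sentinel for out-of-bounds indices instead of
# clamping, which makes the bad index easy to spot.
print(x.at[4].get(mode="fill", fill_value=-1))  # → -1
```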
The error only occurred with jit. When I used jax.disable_jit(), the error did not occur.
But the error also did not occur when I used jit(value_and_grad(fun)). So it only occurred if the fori_loop of the simulator was compiled but value_and_grad was called without being explicitly jitted.