
# Future reference: error at grad: INTERNAL: Failed to execute XLA Runtime executable #53

Closed
michaeldeistler opened this issue May 5, 2023 · 2 comments

@michaeldeistler (Contributor)

```
INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to launch CUDA kernel: fusion_1560 with block dimensions: 160x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
```

@michaeldeistler (Contributor, Author) commented May 6, 2023

Solution:

1. I had some out-of-bounds indexing in my code. E.g.

   ```python
   import jax.numpy as jnp

   x = jnp.asarray([0, 1, 2])
   y = x[4]
   ```

   would not throw an error in JAX, but it broke the gradient computation in my case. See also this JAX issue, which describes that, during backprop, out-of-bounds indexing leads to undefined behaviour in JAX. In my case, out-of-bounds indexing occurred in two places: (1) the parameters of each cell were flattened before I put them into neurax (instead of keeping them as a list), and (2) I applied stimuli to 10 neurons even when fewer than 10 neurons were available. A small sketch of how to surface such indexing bugs is shown after this list.

2. The error only occurred with jit. When I used jax.disable_jit(), the error did not occur.

3. But the error also did not occur when I used jit(value_and_grad(fun)). So it only occurred if the fori_loop of the simulator was compiled but value_and_grad was called without explicitly jitting it (see the second sketch below).
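
A minimal sketch of the silent out-of-bounds behaviour from point 1. The array `x` and index `idx` are made up for illustration; with a dynamic (array-valued or traced) index, JAX clamps the read instead of raising, and `.at[...].get(mode="fill", ...)` is one way to make such reads visible:

```python
import jax
import jax.numpy as jnp

x = jnp.asarray([0.0, 1.0, 2.0])
idx = jnp.asarray(4)  # dynamic index, out of bounds for a length-3 array

# A dynamic out-of-bounds read does not raise: on the forward pass the
# index is clamped to the last valid element.
print(x[idx])  # returns x[2], no exception

# One way to surface the bug: mode="fill" returns a fill value (here NaN)
# for out-of-bounds reads, which is easy to spot in outputs.
print(x.at[idx].get(mode="fill", fill_value=jnp.nan))  # -> nan

# Gradients are where this really hurts: the forward pass clamps the index,
# but reverse mode handles the out-of-bounds index differently, so the
# resulting gradient is not meaningful.
print(jax.grad(lambda a: a[idx])(x))
```

`jax.experimental.checkify` (with `checkify.index_checks`) can also turn out-of-bounds reads into explicit runtime errors, at some cost in speed.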
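
And a sketch of the jitting difference from point 3. `simulate`, `loss`, and the toy dynamics below are made-up stand-ins, not the actual neurax model; the only point is where `jit` sits relative to `value_and_grad` around a `fori_loop`-based simulation:

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

def simulate(params, n_steps=100):
    """Toy stand-in for the simulator: a fori_loop over time steps."""
    def body(_, v):
        return v + 0.01 * (params - v)  # made-up dynamics, not the real model
    return jax.lax.fori_loop(0, n_steps, body, jnp.zeros_like(params))

def loss(params):
    return jnp.sum(simulate(params) ** 2)

params = jnp.ones(5)

# The pattern that triggered the CUDA error in this issue:
# grad_ode = value_and_grad(loss)

# The pattern that worked: jit the entire value-and-gradient computation.
grad_ode = jit(value_and_grad(loss))
value, grads = grad_ode(params)
print(value, grads)
```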

@michaeldeistler (Contributor, Author)

Full trace of many runs I recorded:

### Worked
- parallel_elim = False  
- grad_ode = jit(value_and_grad(ode))  
- sim_inds = np.arange(6)  
- "stone"  
- Connection(0, 0, 0.0, 0, 1, 0.0)  

### Did not work (kernel died)
- parallel_elim = False  
- grad_ode = value_and_grad(ode)  
- sim_inds = np.arange(6)  
- "stone"  
- Connection(0, 0, 0.0, 0, 1, 0.0)  

### Worked
- parallel_elim = True  
- grad_ode = jit(value_and_grad(ode))  
- sim_inds = np.arange(6)  
- "stone"  
- Connection(0, 0, 0.0, 0, 1, 0.0)  

### Worked
- parallel_elim = True  
- grad_ode = jit(value_and_grad(ode))  
- sim_inds = np.arange(6)  
- "thomas"
- Connection(0, 0, 0.0, 0, 1, 0.0)  

### Worked
- parallel_elim = True  
- grad_ode = jit(value_and_grad(ode))  
- sim_inds = np.arange(15)  
- "thomas"
- Connection(0, 0, 0.0, 0, 1, 0.0)  

--- Up to this point, I had accidentally set `syn_params = [jnp.asarray([0.02] * len(conn)) for conn in conns]` instead of 0.
