```python3
import jax
import jax.numpy as np

x = np.arange(10)
x = jax.device_put(x)  # move the array to the device
print(x[[13]])         # index 13 is out of bounds for length 10
```

This prints `[3]`, but it should raise an _out of bounds_ error, as the original NumPy does.

P.S.: Why does `np.arange` return a host array? Is this intended behavior, or should it behave like `np.array` and return a device array?
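
For comparison, here is a minimal sketch of the behavior I would expect on the out-of-bounds point, using plain NumPy on the same array and index. The alias `onp` and the `raised` flag are only for illustration; `onp` is chosen so it does not clash with `jax.numpy as np` in the snippet above.

```python3
import numpy as onp  # plain NumPy; `onp` is an illustrative alias, not part of the report above

x = onp.arange(10)

try:
    x[[13]]  # same out-of-bounds index as in the JAX snippet
    raised = False
except IndexError as err:
    # Plain NumPy rejects the index instead of returning a value.
    raised = True
    print("IndexError:", err)
```

With plain NumPy the lookup fails with an `IndexError` rather than silently returning an element.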