
jax.numpy array indexing has different out-of-bounds behavior from NumPy #278

@jonasrauber

Description

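```python
import jax
import jax.numpy as np  # jax.numpy, aliased as np in this report

x = np.arange(10)
x = jax.device_put(x)  # explicitly place the array on the default device
print(x[[13]])  # index 13 is out of bounds for an array of length 10
```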

This prints [3], but it should raise an out-of-bounds IndexError, as NumPy does.
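A minimal sketch comparing the two behaviors, assuming a recent JAX release that provides the `.at[...].get()` interface with a `mode` argument (this postdates the issue). NumPy raises `IndexError`, while jax.numpy does not raise on an out-of-bounds gather; instead you can request explicit semantics such as `"clip"` or `"fill"`. `checked_getitem` is a hypothetical helper, not part of JAX, that reproduces NumPy's check for concrete indices only: inside `jit` the index values are traced and cannot be inspected on the host. Plain NumPy is imported as `onp` to avoid clashing with this report's `np` alias for jax.numpy.

```python
import numpy as onp  # classic NumPy, for comparison
import jax.numpy as jnp

x_np = onp.arange(10)
x = jnp.arange(10)
idx = onp.array([13])  # out of bounds for length 10

# NumPy rejects the out-of-bounds index.
try:
    x_np[idx]
    numpy_raised = False
except IndexError:
    numpy_raised = True

# jax.numpy: choose the out-of-bounds semantics explicitly.
clipped = x.at[idx].get(mode="clip")               # index clamped to 9
filled = x.at[idx].get(mode="fill", fill_value=-1)  # out-of-bounds slots get -1


def checked_getitem(arr, indices):
    """Hypothetical helper: NumPy-style bounds check along axis 0.

    Works only for concrete (non-traced) indices, i.e. outside jit.
    """
    indices = onp.asarray(indices)
    n = arr.shape[0]
    bad = (indices < -n) | (indices >= n)
    if bad.any():
        raise IndexError(
            f"index {indices[bad][0]} is out of bounds for axis 0 with size {n}"
        )
    return arr[indices]


print(numpy_raised, clipped, filled)
```

Which mode fits depends on the use case: `"clip"` matches the clamping that JAX's documentation describes for out-of-bounds retrievals, while `"fill"` makes a bad index visible in the result.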

P.S.: Why does np.arange return a host array? Is this intended, or should it behave like np.array and return a device array?
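A small sketch for checking where an array lives, assuming a recent JAX release with the unified `jax.Array` type and its `devices()` method. In such releases, as far as we know, `jnp.arange` returns a `jax.Array` just like `jnp.array`, so the host-array behavior described above reflects the JAX version current when this issue was filed.

```python
import numpy as onp  # classic NumPy, for comparison
import jax
import jax.numpy as jnp

host = onp.arange(10)      # plain NumPy array in host memory
x = jnp.arange(10)         # jax.numpy result
y = jax.device_put(host)   # explicit transfer to the default device

for name, arr in [("numpy", host), ("jnp.arange", x), ("device_put", y)]:
    on_device = isinstance(arr, jax.Array)
    # devices() reports the device(s) holding a jax.Array's data.
    where = arr.devices() if on_device else "host (NumPy)"
    print(name, type(arr).__name__, where)
```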

Metadata

Labels

NVIDIA GPU: Issues specific to NVIDIA GPUs
P2 (eventual): This ought to be addressed, but has no schedule at the moment. (Assignee optional)
documentation
