I was surprised to discover that vmap does not always convert NumPy arrays into JAX arrays. Instead, in rare cases it returns NumPy arrays when given NumPy arrays as inputs.
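A minimal sketch of the comparison, assuming a small int32 NumPy input (`np_array` is an illustrative name). The report observes that the identity case comes back as a NumPy array on jax 0.4.38. That result may differ on other JAX versions, so the sketch only prints it. The function that performs real work is traced into at least one primitive, so its output is a `jax.Array`.

```python
import numpy as np
import jax

np_array = np.arange(3, dtype=np.int32)

# Identity-like function: after tracing, vmap records no primitive operations.
out_identity = jax.vmap(lambda x: x)(np_array)

# A function that does real work is traced into a primitive (add).
out_work = jax.vmap(lambda x: x + 1)(np_array)

print(type(out_identity))  # the report observes numpy.ndarray here (jax 0.4.38)
print(type(out_work))      # a concrete jax.Array implementation
```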
The exceptions appear to be cases where the wrapped function reduces to the identity, at least as far as jaxprs are concerned:
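```python
import jax, jax.numpy as jnp
jax_array = jnp.arange(3)  # int32 by default, matching i32[3] below
print(jax.make_jaxpr(jax.vmap(jnp.asarray))(jax_array))
# { lambda ; a:i32[3]. let in (a,) }
```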
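One way to test the "identity" explanation is to count the equations in each traced jaxpr. This is a sketch with illustrative variable names. An empty equation list means the vmapped output is the input value itself, which is consistent with vmap handing back the original object unchanged. A function containing a real primitive (here `add`) produces at least one equation.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)

identity_jaxpr = jax.make_jaxpr(jax.vmap(lambda v: v))(x)
asarray_jaxpr = jax.make_jaxpr(jax.vmap(jnp.asarray))(x)
add_jaxpr = jax.make_jaxpr(jax.vmap(lambda v: v + 1))(x)

# Print how many primitive equations each traced function contains.
for name, closed in [("identity", identity_jaxpr),
                     ("asarray", asarray_jaxpr),
                     ("add", add_jaxpr)]:
    print(name, len(closed.jaxpr.eqns))
```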
However, jnp.asarray() explicitly converts its input into a JAX array, so I found this doubly surprising.
My real use case is a custom pytree type whose constructor (but not tree_unflatten) always converts its inputs into JAX arrays. The type otherwise seems perfectly well behaved, and this is one of the patterns suggested in the JAX docs. However, when vmap is applied to this "identity function", it returns a value that would otherwise be impossible to create.
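A sketch of this pattern, assuming a hypothetical single-field pytree `Point`. Its constructor calls `jnp.asarray`, while `tree_unflatten` bypasses `__init__`. Inside vmap, `jnp.asarray` sees a tracer and records no equation, so on the reported version the original NumPy array can end up stored in `Point.x`. The re-normalization with `jax.tree_util.tree_map(jnp.asarray, ...)` afterwards is an assumed workaround, not a fix from the JAX maintainers.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Point:
    """Hypothetical pytree whose constructor always stores a JAX array."""

    def __init__(self, x):
        self.x = jnp.asarray(x)

    def tree_flatten(self):
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so leaves (e.g. tracers) are stored untouched.
        obj = object.__new__(cls)
        (obj.x,) = children
        return obj


np_array = np.arange(3, dtype=np.int32)
out = jax.vmap(lambda v: Point(v))(np_array)
print(type(out.x))  # the report observes a NumPy array here (jax 0.4.38)

# Assumed workaround: convert every leaf outside of any trace.
fixed = jax.tree_util.tree_map(jnp.asarray, out)
print(type(fixed.x))  # a jax.Array
```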
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='a953559aba37', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')