How to use vmap without collating the output #12345
-
Hi, sorry if this has been asked before; I couldn't find an answer. I have a function that returns a dict, and I want `vmap` to return an array of dicts, not a dict of arrays. The code below shows what I want and what I got. Any suggestions?

```python
from jax import vmap
from jax.random import uniform, split, PRNGKey


def sample_dict(key):
    # Split the key so x and y are drawn independently.
    key, subkey = split(key)
    x = uniform(key, (3,))
    y = uniform(subkey, (3,))
    return dict(x=x, y=y)


# Map over a batch of 3 keys.
v_sample_dict = vmap(sample_dict)
out = v_sample_dict(split(PRNGKey(0), 3))

print("what I got")
print(out)  # a single dict; each value has a leading batch axis of size 3

print("what I want")
print([dict(x=x, y=y) for x, y in zip(out['x'], out['y'])])
```
-
Great question!

There's no way in JAX to represent an array of dicts; the only allowed element types are numeric NumPy dtypes like `int32`, `float32`, etc. Another way to think about it is that you have to work in structure-of-arrays form, and you can't have arrays-of-structures. That's why `vmap` gives you back a single dict whose values each carry a leading batch axis.
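As a minimal sketch of living with structure-of-arrays form, the snippet below uses `jax.tree_util` to turn the batched dict returned by `vmap` into a plain Python list of per-sample dicts (the same result as the list comprehension in the question, but for any pytree), and to stack such a list back into one dict of arrays. The helpers `unstack` and `stack` are hypothetical names for illustration, not JAX APIs, and they assume every leaf shares the same leading batch axis, as `vmap` outputs with the default `out_axes=0` do.

```python
import jax
import jax.numpy as jnp
from jax import vmap
from jax.random import PRNGKey, split, uniform


def sample_dict(key):
    # Same function as in the question: a dict of two length-3 arrays.
    key, subkey = split(key)
    return dict(x=uniform(key, (3,)), y=uniform(subkey, (3,)))


def unstack(tree):
    """Hypothetical helper: structure-of-arrays -> list of structures.

    Assumes all leaves share the same leading (batch) axis.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]
    # Rebuild the original structure from the i-th slice of every leaf.
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n)
    ]


def stack(trees):
    """Hypothetical helper: list of structures -> structure-of-arrays."""
    # tree_map walks the structures in lockstep and stacks matching leaves.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)


out = vmap(sample_dict)(split(PRNGKey(0), 3))  # dict: x and y have shape (3, 3)
per_sample = unstack(out)                      # list of 3 dicts of (3,) arrays
restored = stack(per_sample)                   # back to a dict of (3, 3) arrays
```

The list of dicts is just a host-side Python container of many small arrays, so `vmap` cannot map over the list itself; you would normally call `stack` before feeding it back into JAX transformations. If you only need one sample, `jax.tree_util.tree_map(lambda a: a[i], out)` extracts the i-th dict without building the whole list.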
We've batted around some ideas to let you write arrays-of-structures directly, but we haven't made it a priority. I'd be interested to hear about your use case, and whether working with structures-of-arrays is really painful there.