
how to use vmap without collating the output #12345

Answered by mattjj
aldopareja asked this question in Q&A

Great question!

There's no way in JAX to represent an array of dicts; the only allowed element types are numeric NumPy dtypes like int32, float32, etc.
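A minimal sketch of what this means for `vmap`, assuming a hypothetical per-example function `f` that returns a dict. Since there is no array-of-dicts type, `jax.vmap` returns a single dict whose values are batched arrays. It does not return a list of dicts.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical per-example function that returns a dict (a pytree).
    return {"square": x ** 2, "double": 2.0 * x}


xs = jnp.arange(3.0)  # a batch of 3 scalar inputs
out = jax.vmap(f)(xs)

# `out` is ONE dict, not a list of three dicts; each value carries
# a leading batch axis of size 3.
print({k: v.shape for k, v in out.items()})
```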

Another way to think about it is that you have to work in structure-of-arrays form, and you can't have arrays-of-structures.
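If you need a plain Python list of per-example dicts (arrays-of-structures) outside of JAX transformations, one option is to convert explicitly with `jax.tree_util.tree_map`. The helper names `unstack_tree` and `stack_trees` below are hypothetical. The sketch assumes every leaf shares the same leading batch axis, and the list form exists only at the Python level. It is never a JAX array.

```python
import jax
import jax.numpy as jnp


def unstack_tree(batched):
    # Structure-of-arrays -> arrays-of-structures (a Python list):
    # slice index i out of every leaf; assumes all leaves share axis 0.
    n = jax.tree_util.tree_leaves(batched)[0].shape[0]
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], batched) for i in range(n)]


def stack_trees(items):
    # Arrays-of-structures -> structure-of-arrays: stack the matching
    # leaves of each pytree along a new leading axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *items)


batched = {"square": jnp.array([0.0, 1.0, 4.0]), "double": jnp.array([0.0, 2.0, 4.0])}
items = unstack_tree(batched)   # list of 3 dicts of scalars
restacked = stack_trees(items)  # back to one dict of shape-(3,) arrays
```

Inside jitted or vmapped code you would normally stay in the stacked form and only unstack at the boundary.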

We've batted around some ideas for letting you write arrays-of-structures directly, but we haven't made it a priority. I'd be interested to hear about your use case, and whether working in structure-of-arrays form is really painful for it.

Replies: 1 comment, 6 replies

@mattjj

@aldopareja

@mattjj

@mattjj

@froystig

Answer selected by aldopareja
Category: Q&A
Labels: none
3 participants