lax.while_loop (and probably other operations) crash when passed nonhashable cond/body functions. #13554

Open
patrick-kidger opened this issue Dec 8, 2022 · 0 comments
Labels
bug Something isn't working

patrick-kidger commented Dec 8, 2022

```python
import jax.lax as lax

class Cond:
    def __call__(self, carry):
        return carry < 4

    __hash__ = None  # makes instances unhashable (but still callable)

class Body:
    def __call__(self, carry):
        return carry + 1

    __hash__ = None  # makes instances unhashable (but still callable)

def cond(carry):
    return carry < 4

def body(carry):
    return carry + 1

# Unhashable body function:
try:
    lax.while_loop(cond, Body(), 0)
except TypeError as e:
    print(e)
# Unhashable cond function:
try:
    lax.while_loop(Cond(), body, 0)
except TypeError as e:
    print(e)
```

produces

```
unhashable type: 'Body'
unhashable type: 'Cond'
```
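Until this is fixed, one possible user-side workaround (an assumption on my part, not an official recommendation) is to wrap the unhashable callable in an ordinary Python function. Plain functions are hashed by identity, so `lax.while_loop` should accept them, and the wrapper simply forwards every call to the original object. `as_function` is a hypothetical helper name.

```python
import jax.lax as lax


class Cond:
    def __call__(self, carry):
        return carry < 4

    __hash__ = None  # unhashable, but still callable


class Body:
    def __call__(self, carry):
        return carry + 1

    __hash__ = None  # unhashable, but still callable


def as_function(f):
    """Hypothetical helper: wrap a possibly-unhashable callable in a plain function.

    The returned function is hashable by identity and just forwards its arguments.
    """
    def wrapped(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapped


result = lax.while_loop(as_function(Cond()), as_function(Body()), 0)
print(result)  # 4
```

Note that a fresh wrapper is created on every call, so any caching keyed on the function object will not be reused across calls; this trades cache hits for not crashing.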

In practice, Equinox passes around many callable PyTrees (such as `eqx.Module` instances), whose hash is computed from the hashes of their leaves. If any of those leaves is a JAX array, the PyTree is unhashable, even though it is still callable.
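To illustrate the mechanism without depending on Equinox, the sketch below uses a frozen dataclass as a stand-in for a callable PyTree. This is an assumption: it mimics the leaf-derived hashing described here, but it is not registered as a PyTree and is not Equinox's actual implementation. A frozen dataclass generates `__hash__` from its fields, so a JAX-array field makes the whole object unhashable while leaving it callable. `Scale` is a hypothetical name.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Scale:
    """Hypothetical callable object whose hash is derived from its fields."""
    factor: jnp.ndarray

    def __call__(self, carry):
        return carry * self.factor


scale = Scale(jnp.array(2.0))
print(scale(3.0))  # calling works: 6.0

try:
    hash(scale)  # hashes the tuple of fields, which contains a JAX array
except TypeError as e:
    print("hash failed:", e)
```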

Probably this is low-priority, but fixing it would be easy enough, so I'm opening this issue mostly as a reminder for myself in case I ever get time to come back and sort this out.
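The issue doesn't spell out the fix. One plausible library-side approach, sketched here purely as a hypothetical (it is not JAX's code, and it assumes the crash comes from the function being used as a cache key), is to wrap the user's callable before caching: keep value-based hashing and equality when the callable supports them, so existing cache hits are preserved, and fall back to identity otherwise. `HashableCallable` is an illustrative name.

```python
from typing import Any, Callable


class HashableCallable:
    """Hypothetical cache-key wrapper: compare by value if possible, else by identity."""

    def __init__(self, fn: Callable[..., Any]):
        self.fn = fn
        try:
            hash(fn)
            self._by_value = True
        except TypeError:
            # Unhashable callable (e.g. a PyTree with array leaves): use identity.
            self._by_value = False

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __hash__(self):
        # Holding self.fn keeps the object alive, so id() stays unique while cached.
        key = self.fn if self._by_value else id(self.fn)
        return hash((self._by_value, key))

    def __eq__(self, other):
        if not isinstance(other, HashableCallable):
            return NotImplemented
        if self._by_value != other._by_value:
            return False
        if self._by_value:
            return self.fn == other.fn
        return self.fn is other.fn
```

With this design, ordinary functions keep their existing caching behaviour, while unhashable callables only get a cache hit when the very same object is passed again.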
