
Dimension changing #264

Open
samuelstevens opened this issue Nov 5, 2024 · 2 comments
Labels
question User queries

Comments

@samuelstevens

```python
import beartype
import torch

from jaxtyping import Float, jaxtyped
from torch import Tensor


@beartype.beartype
class Identity(torch.nn.Module):
    @jaxtyped(typechecker=beartype.beartype)
    def forward(self, x: Float[Tensor, "batch d_model"]):
        return x


@jaxtyped(typechecker=beartype.beartype)
def main():
    model = Identity()

    x = torch.ones((4, 1), dtype=torch.float)
    model(x)

    x = torch.ones((6, 1), dtype=torch.float)
    model(x)  # <- Throws an error about "batch" not matching.


if __name__ == "__main__":
    main()
```

I'm not sure why this is happening: the second call to `model.forward` raises an error complaining that "batch" has changed between calls. How can I fix this?

@patrick-kidger
Owner

I think the additional `beartype` decorator you have on the class is responsible. If I remove it, the code passes successfully.

@patrick-kidger patrick-kidger added the question User queries label Nov 6, 2024
@samuelstevens
Author

Yeah, and it also passes if I replace it with `jaxtyped(typechecker=beartype.beartype)`. I'm not sure why this is the case, or how to think about it so I can avoid it in the future. If you have any ideas, I'd appreciate them; if not, feel free to close this issue. Thanks!
