Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPT: Change order of operands to enable PT2 compile for inference #559

Merged
merged 2 commits into from
Aug 28, 2023

Conversation

tdoublep
Copy link
Contributor

This is a trivial change, which doesn't actually change any logic, but it very helpful to enable PyTorch 2 compile for inference using mpt-based models.

During compilation, evaluating attention_mask[:, 0].sum() != attention_mask.shape[0] seems to cause a bunch of problems because it compares the shape of the tensor to its actual contents, and leads to graph breaks. By changing the ordering of the operands in this if statement, the problem is resolved since self.training gets evaluated first and fails, thus preventing the rest from being evaluated.

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@tdoublep
Copy link
Contributor Author

Anything I can do regarding the failing checks? Not 100% sure but doesn't seem to be related to the proposed code change.

@dakinggg
Copy link
Collaborator

Its just the autoformatting, if you run pre-commit run --all-files and commit, it will fix it.

@tdoublep
Copy link
Contributor Author

@dakinggg I ran the code formatting, but looks like an approval is needed for the checks to re-run.

@dakinggg dakinggg merged commit a8c7dc4 into mosaicml:main Aug 28, 2023
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants