Fix calculate_macs() for Linear layers. #318
Conversation
Could you add a unit test and also explain why this fixes the issue?
Currently, torchinfo only handles two-dimensional input tensors correctly; any additional dimensions are ignored:

```python
self.macs += self.output_size[0] * cur_params
```

The change accounts for all of the leading dimensions:

```python
elif "Linear" in self.class_name:
    self.macs += int(cur_params * prod(self.output_size[:-1]))
```

See section A.2.3 N-mode Product in this reference for a detailed explanation.
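As a concrete illustration of the difference (hypothetical shapes, not taken from the PR), consider a Linear layer applied to a 3-D input such as `(batch, seq_len, features)`:

```python
from math import prod

def linear_macs(cur_params: int, output_size: list[int]) -> int:
    """MACs for a Linear layer: the weight matrix is applied once per
    position in every leading (non-feature) dimension of the output."""
    return int(cur_params * prod(output_size[:-1]))

# Hypothetical shapes: batch=4, seq_len=16, in_features=32, out_features=64.
# cur_params = 32 * 64 weight parameters (bias ignored for simplicity).
cur_params = 32 * 64
output_size = [4, 16, 64]

# Old (buggy) formula only used output_size[0]:
old_macs = output_size[0] * cur_params           # 4 * 2048 = 8192
# Fixed formula accounts for all leading dimensions:
new_macs = linear_macs(cur_params, output_size)  # 4 * 16 * 2048 = 131072
print(old_macs, new_macs)
```

With only a 2-D output the two formulas agree, which is why the bug went unnoticed until models with extra leading dimensions (e.g. sequence models) were profiled.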
Fix MACs in lst.out and lstm_half.out.
(force-pushed from 9a483b9 to c2da67d)
I corrected the output files for the LSTM tests, but the flan_t5 tests are still broken. I am not familiar with that model, so I would have to learn how to calculate the number of operations correctly.
Also, judging by the code, any unrecognized layer is assumed to be a Linear layer with 2D input.
I added a unit test for torch.nn.Linear that uses a 3D input tensor.
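For reference, here is a minimal, self-contained sketch of such a test (pure Python, with assumed names and shapes; the PR's actual test presumably builds a real `torch.nn.Linear` and runs `torchinfo.summary`, which is not reproduced here):

```python
import unittest
from math import prod

def linear_macs(cur_params, output_size):
    # Corrected formula: per-position cost multiplied by every
    # leading (non-feature) dimension of the output shape.
    return int(cur_params * prod(output_size[:-1]))

class TestLinearMacs(unittest.TestCase):
    def test_2d_input(self):
        # 2-D case: batch=8, 32 -> 64 features (32 * 64 weight params).
        self.assertEqual(linear_macs(32 * 64, [8, 64]), 8 * 32 * 64)

    def test_3d_input(self):
        # 3-D case: an extra sequence dimension of 16 must scale the MACs.
        self.assertEqual(linear_macs(32 * 64, [8, 16, 64]), 8 * 16 * 32 * 64)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestLinearMacs)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

The 3-D case is the one the old formula got wrong: it would have reported the same MACs as the 2-D case regardless of the sequence length.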
Thanks for the explanation. It seems very reasonable that this error would cause other model outputs to change, so feel free to update those as suggested by the test errors. I think this specific calculation has a lot of room for improvement, since it was only hardcoded with a few layer types to begin with; it would be great to enhance it further.
It looks like torch.nn.Embedding layers have the same bug.
MACs increased from 280.27M to 18.25G because of the Linear layer fix.
@TylerYep, I followed your guidance and changed the ground truth in the affected output files. Let me know if you prefer the three commits in this pull request to be squashed into a single commit.
Looks good. I'll happily accept more PRs expanding this functionality. Thank you for your contributions!
Fixes #77.