Fix calculate_macs() for Linear layers. #318

Merged (3 commits) on Nov 5, 2024

Conversation

andravin
Contributor

Fixes #77.

Owner

TylerYep commented Nov 1, 2024

Could you add a unit test and also explain why this fixes the issue?

Hardcoding Linear is not ideal, but if there is something special about Linear layers that requires this fix I would like to understand it more completely.

Contributor Author

andravin commented Nov 1, 2024

The documentation for torch.nn.Linear says the input tensor can have "any number of dimensions."

Currently, torchinfo handles only two-dimensional input tensors correctly; any additional dimensions are ignored:

                    self.macs += self.output_size[0] * cur_params

The change accounts for all of the leading dimensions:

                elif "Linear" in self.class_name:
                    self.macs += int(cur_params * prod(self.output_size[:-1]))

You can think of torch.nn.Linear as a tensor-matrix multiplication or n-mode product, with the last dimension of the tensor as the "mode." This is equivalent to matrix multiplication if we first unfold all of the leading dimensions of the tensor into a single dimension. The size of the unfolded dimension is prod(self.output_size[:-1]).

See section A.2.3 N-mode Product in this reference for a detailed explanation.
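
To make the difference concrete, here is a minimal sketch that compares the two formulas quoted above on a 3D output shape. The helper names `linear_macs_old` and `linear_macs_fixed` are hypothetical and exist only for this illustration, and the layer sizes are made up; it assumes a bias-free Linear layer, so `cur_params` is just the weight count.

```python
from math import prod


def linear_macs_old(output_size, cur_params):
    # Previous behavior: only the first (batch) dimension is counted.
    return output_size[0] * cur_params


def linear_macs_fixed(output_size, cur_params):
    # Fixed behavior: every leading dimension is counted.
    return int(cur_params * prod(output_size[:-1]))


# Assumed example: Linear(in_features=16, out_features=32) without bias.
cur_params = 16 * 32

# Output shape for an input of shape (8, 10, 16): batch 8, sequence 10.
output_size = [8, 10, 32]

old = linear_macs_old(output_size, cur_params)
fixed = linear_macs_fixed(output_size, cur_params)

# Unfolding the leading dimensions into one (8 * 10 = 80 rows) gives an
# ordinary matrix multiplication with the same MAC count as the fixed formula.
unfolded = linear_macs_fixed([8 * 10, 32], cur_params)

print(old, fixed, unfolded)
```

For a 2D output such as `[8, 32]` the two formulas agree, which is why the bug only shows up with extra leading dimensions.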

Fix MACs in lst.out and lstm_half.out.
andravin force-pushed the fix-linear-layer-macs branch from 9a483b9 to c2da67d on November 1, 2024 21:29
Contributor Author

andravin commented Nov 1, 2024

I corrected the output files for the LSTM tests, but the flan_t5 tests are still broken. I am not familiar with that model, so I would have to learn how to calculate the number of operations correctly.

Contributor Author

andravin commented Nov 1, 2024

LayerInfo.calculate_macs checks only for a partial match on the class name (Conv, and now Linear) and ignores the module name, which seems dicey. Exact matches against fully qualified <module>.<class> names would be more robust.

Also, any unrecognized layer is assumed to be a Linear layer with 2D input, judging by the code in the else block. That also seems error-prone. It might be better to throw an exception for unsupported layers.
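
A rough sketch of what that could look like, not a proposal for the actual torchinfo code: `MAC_RULES` and `macs_for` are hypothetical names, and only the Linear rule from this PR is registered. For a real module, the key could be built as `f"{type(module).__module__}.{type(module).__qualname__}"`, which for `torch.nn.Linear` is `torch.nn.modules.linear.Linear`.

```python
from math import prod


def linear_macs(output_size, cur_params):
    # Linear rule from this PR: count all leading dimensions.
    return int(cur_params * prod(output_size[:-1]))


# Hypothetical registry keyed by exact "<module>.<class>" names.
MAC_RULES = {
    "torch.nn.modules.linear.Linear": linear_macs,
}


def macs_for(qualified_name, output_size, cur_params):
    # Exact lookup: a class that merely contains "Linear" in its name
    # does not silently inherit the Linear rule.
    try:
        rule = MAC_RULES[qualified_name]
    except KeyError:
        raise NotImplementedError(
            f"no MAC rule registered for {qualified_name}"
        ) from None
    return rule(output_size, cur_params)


print(macs_for("torch.nn.modules.linear.Linear", [8, 10, 32], 16 * 32))

try:
    # Made-up class name that a partial "Linear" match would have accepted.
    macs_for("my_pkg.layers.LinearAttention", [8, 10, 32], 16 * 32)
except NotImplementedError as exc:
    print(exc)
```

The trade-off is that subclasses and third-party layers would need explicit registration instead of getting a possibly wrong default.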

Contributor Author

andravin commented Nov 2, 2024

I added a unit test for torch.nn.Linear that uses a 3D input tensor.

Owner

TylerYep commented Nov 2, 2024

Thanks for the explanation. It seems very reasonable that this error would cause other model outputs to change, so feel free to update those as suggested by the test errors.

I think this specific calculation has a lot of room for improvement, since it was only hardcoded with a few layer types to begin with. It would be great to enhance the <module>.<class> exact match code as well, but feel free to tackle these things in a separate PR.

Contributor Author

andravin commented Nov 2, 2024

It looks like torch.nn.Embedding layers have the same bug that torch.nn.Linear layers did.

 MACs increased from 280.27M to 18.25G because of the Linear layer fix.
Contributor Author

andravin commented Nov 3, 2024

@TylerYep, I followed your guidance and changed the ground truth in flan_t5_small.out to equal the output of the unit tests with the linear layer MACs fix.

Let me know if you prefer the three commits in this pull request to be squashed into a single commit.

Owner

TylerYep commented Nov 5, 2024

Looks good. I'll happily accept more PRs expanding this functionality. Thank you for your contributions!

TylerYep merged commit 38ab72b into TylerYep:main on Nov 5, 2024
30 checks passed
Successfully merging this pull request may close these issues.

Error in computing Linear Layer Multiply adds
3 participants