Enable analyzing nested input- and output-dicts #212

Merged
merged 16 commits on Feb 5, 2023
Changes from 6 commits
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -56,6 +56,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install mypy pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }} torchvision
pip install transformers
- name: mypy
if: ${{ matrix.pytorch-version == '1.13' }}
run: |
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -8,3 +8,4 @@ pylint
pytest
pytest-cov
pre-commit
transformers
17 changes: 17 additions & 0 deletions tests/fixtures/models.py
@@ -323,6 +323,23 @@ def forward(
return x


class HighlyNestedDictModel(nn.Module):
"""Model that returns a highly nested dict."""

def __init__(self) -> None:
super().__init__()
self.lin1 = nn.Linear(10, 10)
self.lin2 = nn.Linear(10, 10)

def forward(
self, x: torch.Tensor
) -> dict[str, tuple[dict[str, list[torch.Tensor]]]]:
x = self.lin1(x)
x = self.lin2(x)
x = F.softmax(x)
return {"foo": ({"bar": [x]},)}


class NamedTuple(nn.Module):
"""Model that takes in a NamedTuple as input."""

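For context, a minimal usage sketch of the new fixture (it mirrors the test_highly_nested_dict_model test added in tests/torchinfo_test.py below): summary now traverses the nested dict/tuple/list output instead of failing on it.

import torch

from tests.fixtures.models import HighlyNestedDictModel
from torchinfo import summary

model = HighlyNestedDictModel()
# forward() returns {"foo": ({"bar": [tensor]},)}; summary() can now
# report per-layer output shapes despite the nesting.
summary(model, input_data=torch.ones(10))
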
25 changes: 25 additions & 0 deletions tests/fixtures/torchversion.py
@@ -0,0 +1,25 @@
import torch


def torchversion_at_least(version: str) -> bool:
"""
Returns True if the installed version of torch is at least the given version.
For example, if "1.13.1" is installed, `torchversion_at_least("1.8")` would
yield `True`, but if "1.7.1" is installed, `torchversion_at_least("1.8")` would
yield `False`.
"""
version_installed = torch.__version__.split(".")
Owner:
Let's use the version utils from packaging instead of reimplementing them:
https://stackoverflow.com/questions/11887762/how-do-i-compare-version-numbers-in-python

Reply:
This also makes my test.yml file a lot simpler, thanks for the suggestion

version_given = version.split(".")

for num_installed, num_given in zip(version_installed, version_given):
if int(num_given) < int(num_installed):
return True
if int(num_given) > int(num_installed):
return False

if len(version_given) > len(
version_installed
): # e.g. "1.7.1" installed, "1.7" given
return False

return True
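
Following the review suggestion above, a minimal sketch of the packaging-based alternative (an illustration, assuming the packaging package is installed; it is not part of this diff):

import torch
from packaging import version


def torchversion_at_least(version_str: str) -> bool:
    # version.parse also copes with local build tags such as "1.13.1+cu117",
    # which the manual str.split(".") comparison above can trip over.
    return version.parse(torch.__version__) >= version.parse(version_str)
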
38 changes: 38 additions & 0 deletions tests/test_output/bert.out
@@ -0,0 +1,38 @@
====================================================================================================
Layer (type:depth-idx) Output Shape Param #
====================================================================================================
BertModel [2, 768] --
├─BertEmbeddings: 1-1 [2, 512, 768] --
│ └─Embedding: 2-1 [2, 512, 768] 23,440,896
│ └─Embedding: 2-2 [2, 512, 768] 1,536
│ └─Embedding: 2-3 [1, 512, 768] 393,216
│ └─LayerNorm: 2-4 [2, 512, 768] 1,536
│ └─Dropout: 2-5 [2, 512, 768] --
├─BertEncoder: 1-2 [2, 512, 768] --
│ └─ModuleList: 2-6 -- --
│ │ └─BertLayer: 3-1 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-2 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-3 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-4 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-5 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-6 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-7 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-8 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-9 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-10 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-11 [2, 512, 768] 7,087,872
│ │ └─BertLayer: 3-12 [2, 512, 768] 7,087,872
├─BertPooler: 1-3 [2, 768] --
│ └─Linear: 2-7 [2, 768] 590,592
│ └─Tanh: 2-8 [2, 768] --
====================================================================================================
Total params: 109,482,240
Trainable params: 109,482,240
Non-trainable params: 0
Total mult-adds (M): 218.57
====================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 852.50
Params size (MB): 437.93
Estimated Total Size (MB): 1290.45
====================================================================================================
46 changes: 46 additions & 0 deletions tests/test_output/flan_t5_small.out
@@ -0,0 +1,46 @@
==============================================================================================================
Layer (type:depth-idx) Output Shape Param #
==============================================================================================================
T5ForConditionalGeneration [2, 100, 512] --
├─T5Stack: 1-1 [2, 100, 512] 35,332,800
├─T5Stack: 1-2 -- (recursive)
│ └─Embedding: 2-1 [2, 100, 512] 16,449,536
├─T5Stack: 1-3 -- (recursive)
│ └─Dropout: 2-2 [2, 100, 512] --
│ └─ModuleList: 2-3 -- --
│ │ └─T5Block: 3-1 [2, 100, 512] 2,360,512
│ │ └─T5Block: 3-2 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-3 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-4 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-5 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-6 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-7 [2, 100, 512] 2,360,320
│ │ └─T5Block: 3-8 [2, 100, 512] 2,360,320
│ └─T5LayerNorm: 2-4 [2, 100, 512] 512
│ └─Dropout: 2-5 [2, 100, 512] --
├─T5Stack: 1-4 [2, 6, 100, 64] 16,449,536
│ └─Embedding: 2-6 [2, 100, 512] (recursive)
│ └─Dropout: 2-7 [2, 100, 512] --
│ └─ModuleList: 2-8 -- --
│ │ └─T5Block: 3-9 [2, 100, 512] 3,147,456
│ │ └─T5Block: 3-10 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-11 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-12 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-13 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-14 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-15 [2, 100, 512] 3,147,264
│ │ └─T5Block: 3-16 [2, 100, 512] 3,147,264
│ └─T5LayerNorm: 2-9 [2, 100, 512] 512
│ └─Dropout: 2-10 [2, 100, 512] --
├─Linear: 1-5 [2, 100, 32128] 16,449,536
==============================================================================================================
Total params: 128,743,488
Trainable params: 128,743,488
Non-trainable params: 0
Total mult-adds (M): 186.86
==============================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 217.84
Params size (MB): 307.84
Estimated Total Size (MB): 525.69
==============================================================================================================
17 changes: 17 additions & 0 deletions tests/test_output/highly_nested_dict_model.out
@@ -0,0 +1,17 @@
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
HighlyNestedDictModel [10] --
├─Linear: 1-1 [10] 110
├─Linear: 1-2 [10] 110
==========================================================================================
Total params: 220
Trainable params: 220
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
6 changes: 6 additions & 0 deletions tests/torchinfo_test.py
@@ -15,6 +15,7 @@
DictParameter,
EmptyModule,
FakePrunedLayerModel,
HighlyNestedDictModel,
InsideModel,
LinearModel,
LSTMNet,
@@ -344,6 +345,11 @@ def test_module_dict() -> None:
)


def test_highly_nested_dict_model() -> None:
model = HighlyNestedDictModel()
summary(model, input_data=torch.ones(10))


def test_model_with_args() -> None:
summary(RecursiveNet(), input_size=(1, 64, 28, 28), args1="args1", args2="args2")

34 changes: 34 additions & 0 deletions tests/torchinfo_xl_test.py
@@ -1,9 +1,15 @@
import pytest
import torch
import torchvision # type: ignore[import]
from transformers import ( # type: ignore[import]
AutoModelForSeq2SeqLM,
BertConfig,
BertModel,
)

from tests.fixtures.genotype import GenotypeNetwork # type: ignore[attr-defined]
from tests.fixtures.tmva_net import TMVANet # type: ignore[attr-defined]
from tests.fixtures.torchversion import torchversion_at_least
from torchinfo import summary
from torchinfo.enums import ColumnSettings

@@ -143,3 +149,31 @@ def test_google() -> None:
# Check googlenet in training mode since InceptionAux layers are used in
# forward-prop in train mode but not in eval mode.
summary(google_net, (1, 3, 112, 112), depth=7, mode="train")


@pytest.mark.skipif(
not torchversion_at_least("1.8"),
reason="FlanT5Small only works for PyTorch v1.8 and above",
)
def test_flan_t5_small() -> None:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
inputs = {
"input_ids": torch.zeros(2, 100).long(),
"attention_mask": torch.zeros(2, 100).long(),
"labels": torch.zeros(2, 100).long(),
}
summary(model, input_data=inputs)


@pytest.mark.skipif(
not torchversion_at_least("1.8"),
reason="BertModel only works for PyTorch v1.8 and above",
)
def test_bert() -> None:
model = BertModel(BertConfig())
summary(
model,
input_size=[(2, 512), (2, 512), (2, 512)],
dtypes=[torch.int, torch.int, torch.int],
device="cpu",
)
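
A note on test_bert for readers unfamiliar with the size-based API: the three (2, 512) entries are passed positionally to BertModel.forward, i.e. as input_ids, attention_mask, and token_type_ids, hence the integer dtypes. A roughly equivalent dict-based call (a sketch, not part of the diff) would be:

inputs = {
    "input_ids": torch.zeros(2, 512, dtype=torch.int),
    "attention_mask": torch.zeros(2, 512, dtype=torch.int),
    "token_type_ids": torch.zeros(2, 512, dtype=torch.int),
}
summary(BertModel(BertConfig()), input_data=inputs, device="cpu")
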
18 changes: 17 additions & 1 deletion torchinfo/layer_info.py
@@ -112,6 +112,20 @@ def nested_list_size(inputs: Sequence[Any]) -> tuple[list[int], int]:
return nested_list_size(inputs[0])
return [], 0

def extract_tensor(inputs: torch.Tensor | Sequence[Any]) -> torch.Tensor:
"""Extracts tensor from sequence."""
if isinstance(inputs, torch.Tensor):
return inputs
if hasattr(inputs, "tensors"):
return extract_tensor(list(inputs.tensors))
if not hasattr(inputs, "__getitem__") or not inputs:
return torch.Tensor([])
if isinstance(inputs, dict):
return extract_tensor(list(inputs.values()))
if isinstance(inputs, (list, tuple)):
return extract_tensor(inputs[0])
return torch.Tensor([])

if inputs is None:
size, elem_bytes = [], 0

@@ -127,9 +141,11 @@ def nested_list_size(inputs: Sequence[Any]) -> tuple[list[int], int]:
elif isinstance(inputs, dict):
# TODO avoid overwriting the previous size every time
size = []
elem_bytes = list(inputs.values())[0].element_size()
elem_bytes = 0
for _, output in inputs.items():
output = extract_tensor(output)
size = list(output.size())
elem_bytes = output.element_size()
if batch_dim is not None:
size = [size[:batch_dim] + [1] + size[batch_dim + 1 :]]

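To make the traversal above concrete, a standalone sketch of what extract_tensor does (simplified; the real helper also unwraps objects exposing a .tensors attribute and guards non-indexable inputs):

import torch


def extract_tensor_sketch(inputs):
    # Recurse through dict values and the heads of lists/tuples until a
    # tensor is found; fall back to an empty tensor otherwise.
    if isinstance(inputs, torch.Tensor):
        return inputs
    if isinstance(inputs, dict):
        return extract_tensor_sketch(list(inputs.values()))
    if isinstance(inputs, (list, tuple)) and inputs:
        return extract_tensor_sketch(inputs[0])
    return torch.Tensor([])


nested = {"foo": ({"bar": [torch.ones(2, 3)]},)}  # HighlyNestedDictModel's output
print(extract_tensor_sketch(nested).shape)  # torch.Size([2, 3])
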
18 changes: 14 additions & 4 deletions torchinfo/model_statistics.py
@@ -33,14 +33,24 @@ def __init__(
self.total_output_bytes += layer_info.output_bytes * 2
if layer_info.is_recursive:
continue
self.total_params += layer_info.num_params
self.total_params += (
layer_info.num_params if layer_info.num_params > 0 else 0
)
self.total_param_bytes += layer_info.param_bytes
self.trainable_params += layer_info.trainable_params
self.trainable_params += (
layer_info.trainable_params
if layer_info.trainable_params > 0
else 0
)
else:
if layer_info.is_recursive:
continue
self.total_params += layer_info.leftover_params()
self.trainable_params += layer_info.leftover_trainable_params()
leftover_params = layer_info.leftover_params()
leftover_trainable_params = layer_info.leftover_trainable_params()
self.total_params += leftover_params if leftover_params > 0 else 0
Owner:
This is suspicious, let's see if we can figure out why params can ever be negative. I'll take a look too

self.trainable_params += (
leftover_trainable_params if leftover_trainable_params > 0 else 0
)
self.formatting.set_layer_name_width(summary_list)

def __repr__(self) -> str:
Expand Down