-
Notifications
You must be signed in to change notification settings - Fork 4
/
summary.py
62 lines (57 loc) · 2.19 KB
/
summary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import torchinfo
from ptflops import get_model_complexity_info
from libs.consts import IMAGENET, CIFAR10, CIFAR100
from libs.datasets import ImageNetGetter, CIFAR10Getter, CIFAR100Getter
from libs.models import RaftMLP
def _make_dataset_getter(settings: DictConfig):
    """Instantiate the dataset getter selected by ``settings.dataset_name``.

    The name is compared with ``==`` against ``IMAGENET``, ``CIFAR10`` and
    ``CIFAR100`` in that order. Only the matching getter is constructed, and
    ``color_jitter`` / ``cutout_p`` are read only once a match is found.

    Args:
        settings: the ``settings`` sub-config of the Hydra config.

    Returns:
        The constructed getter. The caller reads ``channels``, ``image_size``
        and ``num_classes`` from it.

    Raises:
        ValueError: if the dataset name matches none of the known datasets.
    """
    candidates = (
        (IMAGENET, ImageNetGetter),
        (CIFAR10, CIFAR10Getter),
        (CIFAR100, CIFAR100Getter),
    )
    for dataset_name, getter_cls in candidates:
        if settings.dataset_name == dataset_name:
            return getter_cls(
                color_jitter=settings.color_jitter,
                cutout_p=settings.cutout_p,
            )
    raise ValueError("Invalid dataset name")


def _make_model(settings: DictConfig, dg) -> RaftMLP:
    """Build a ``RaftMLP`` from ``settings``, sized by the dataset getter ``dg``.

    Input channels, image size and class count come from ``dg``. Every other
    hyper-parameter is taken verbatim from ``settings``.
    """
    return RaftMLP(
        layers=settings.layers,
        in_channels=dg.channels,
        image_size=dg.image_size,
        num_classes=dg.num_classes,
        token_expansion_factor=settings.token_expansion_factor,
        channel_expansion_factor=settings.channel_expansion_factor,
        dropout=settings.dropout,
        token_mixing_type=settings.token_mixing_type,
        embedding_type=settings.embedding_type,
        shortcut=settings.shortcut,
        gap=settings.gap,
        drop_path_rate=settings.drop_path_rate,
    )


@hydra.main(config_path="configs", config_name="config")
def run_summary(params: DictConfig) -> None:
    """Print the config, a layer summary and the complexity of a RaftMLP model.

    Steps:
      1. Dump the resolved Hydra config as YAML and switch it to struct mode,
         so reading an unknown key raises instead of silently creating it.
      2. Build the dataset getter and a ``RaftMLP`` sized for that dataset.
      3. Print a ``torchinfo`` summary for a batch of one image.
      4. Print ptflops' MAC count and parameter count as human-readable strings.

    Args:
        params: Hydra config; all model/dataset options live under
            ``params.settings``.

    Raises:
        ValueError: if ``params.settings.dataset_name`` is not a known dataset.
    """
    import contextlib  # local import keeps this block self-contained

    print(OmegaConf.to_yaml(params))
    OmegaConf.set_struct(params, True)
    settings = params.settings

    dg = _make_dataset_getter(settings)
    model = _make_model(settings, dg)

    # (C, H, W) for a square image; torchinfo also wants the batch dimension.
    input_size = (dg.channels, dg.image_size, dg.image_size)
    torchinfo.summary(model, input_size=(1, *input_size))

    # The model is never moved to a GPU here, and entering torch.cuda.device(0)
    # fails when CUDA is unavailable. Use a no-op context in that case so the
    # script still works on CPU-only machines.
    device_ctx = (
        torch.cuda.device(0)
        if torch.cuda.is_available()
        else contextlib.nullcontext()
    )
    with device_ctx:
        # Named `num_params` so the `params` config argument is not shadowed.
        macs, num_params = get_model_complexity_info(
            model,
            input_size,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
    print("{:<30} {:<8}".format("Computational complexity: ", macs))
    print("{:<30} {:<8}".format("Number of parameters: ", num_params))
if __name__ == "__main__":
    # The @hydra.main decorator composes the config (plus any CLI overrides)
    # and passes it in, so the entry point is invoked without arguments.
    run_summary()