-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_structure.txt
113 lines (113 loc) · 15.2 KB
/
model_structure.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
======================================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
======================================================================================================================================================
SpaceTimeUnet (SpaceTimeUnet) [1, 4, 70, 80, 64] [1, 4, 70, 80, 64] -- True
├─Sequential (to_timestep_cond) [1] [1, 256] -- True
│ └─SinusoidalPosEmb (0) [1] [1, 64] -- --
│ └─Linear (1) [1, 64] [1, 256] 16,640 True
│ └─SiLU (2) [1, 256] [1, 256] -- --
├─PseudoConv3d (conv_in) [1, 4, 70, 80, 64] [1, 64, 70, 80, 64] -- True
│ └─Conv2d (spatial_conv) [70, 4, 80, 64] [70, 64, 80, 64] 12,608 True
│ └─Conv1d (temporal_conv) [5120, 64, 70] [5120, 64, 70] 12,352 True
├─ModuleList (downs) -- -- -- True
│ └─ModuleList (0) -- -- -- True
│ │ └─ResnetBlock (0) [1, 64, 70, 80, 64] [1, 64, 70, 80, 64] 131,712 True
│ │ └─ModuleList (1) -- -- 197,632 True
│ │ └─Downsample (3) [1, 64, 70, 80, 64] [1, 64, 70, 40, 32] 16,384 True
│ │ └─AttentionBlock (4) [1, 64, 70, 40, 32] [1, 64, 70, 40, 32] 160,704 True
│ └─ModuleList (1) -- -- -- True
│ │ └─ResnetBlock (0) [1, 64, 70, 40, 32] [1, 128, 70, 40, 32] 394,624 True
│ │ └─ModuleList (1) -- -- 788,480 True
│ │ └─Downsample (3) [1, 128, 70, 40, 32] [1, 128, 70, 20, 16] 65,536 True
│ │ └─AttentionBlock (4) [1, 128, 70, 20, 16] [1, 128, 70, 20, 16] 444,288 True
│ └─ModuleList (2) -- -- -- True
│ │ └─ResnetBlock (0) [1, 128, 70, 20, 16] [1, 256, 70, 20, 16] 1,444,608 True
│ │ └─ModuleList (1) -- -- 3,149,824 True
│ │ └─Downsample (3) [1, 256, 70, 20, 16] [1, 256, 70, 10, 8] 262,144 True
│ │ └─AttentionBlock (4) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,380,096 True
│ └─ModuleList (3) -- -- -- True
│ │ └─ResnetBlock (0) [1, 256, 70, 10, 8] [1, 512, 70, 10, 8] 5,510,656 True
│ │ └─ModuleList (1) -- -- 12,591,104 True
│ │ └─SpatioTemporalAttention (2) [1, 512, 70, 10, 8] [1, 512, 70, 10, 8] 4,334,181 True
│ │ └─Downsample (3) [1, 512, 70, 10, 8] [1, 512, 35, 5, 4] 1,572,864 True
│ │ └─AttentionBlock (4) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 4,726,272 True
├─ResnetBlock (mid_block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ └─Sequential (timestep_mlp) [1, 256] [1, 1024] -- True
│ │ └─SiLU (0) [1, 256] [1, 256] -- --
│ │ └─Linear (1) [1, 256] [1, 1024] 263,168 True
│ └─Block (block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
│ └─Block (block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
│ └─Identity (res_conv) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
├─SpatioTemporalAttention (mid_attn) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ └─ContinuousPositionBias (spatial_rel_pos_bias) -- [8, 20, 20] -- True
│ │ └─ModuleList (net) -- -- 68,616 True
│ └─Attention (spatial_attn) [35, 20, 512] [35, 20, 512] -- True
│ │ └─LayerNorm (norm) [35, 20, 512] [35, 20, 512] 1,024 True
│ │ └─Linear (to_q) [35, 20, 512] [35, 20, 512] 262,144 True
│ │ └─Linear (to_kv) [35, 20, 512] [35, 20, 1024] 524,288 True
│ │ └─Linear (to_out) [35, 20, 512] [35, 20, 512] 262,144 True
│ └─ContinuousPositionBias (temporal_rel_pos_bias) -- [8, 35, 35] -- True
│ │ └─ModuleList (net) -- -- 68,360 True
│ └─Attention (temporal_attn) [20, 35, 512] [20, 35, 512] -- True
│ │ └─LayerNorm (norm) [20, 35, 512] [20, 35, 512] 1,024 True
│ │ └─Linear (to_q) [20, 35, 512] [20, 35, 512] 262,144 True
│ │ └─Linear (to_kv) [20, 35, 512] [20, 35, 1024] 524,288 True
│ │ └─Linear (to_out) [20, 35, 512] [20, 35, 512] 262,144 True
│ └─FeedForward (ff) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ │ └─Sequential (proj_in) [1, 512, 35, 5, 4] [1, 1365, 35, 5, 4] 1,397,760 True
│ │ └─Sequential (proj_out) [1, 1365, 35, 5, 4] [1, 512, 35, 5, 4] 700,245 True
├─ResnetBlock (mid_block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ └─Sequential (timestep_mlp) [1, 256] [1, 1024] -- True
│ │ └─SiLU (0) [1, 256] [1, 256] -- --
│ │ └─Linear (1) [1, 256] [1, 1024] 263,168 True
│ └─Block (block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
│ └─Block (block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
│ └─Identity (res_conv) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
├─ModuleList (ups) -- -- -- True
│ └─ModuleList (3) -- -- -- True
│ │ └─Upsample (3) [1, 512, 35, 5, 4] [1, 512, 70, 10, 8] 1,575,936 True
│ │ └─ResnetBlock (0) [1, 1024, 70, 10, 8] [1, 256, 70, 10, 8] 3,738,368 True
│ │ └─ModuleList (1) -- -- 4,526,336 True
│ │ └─SpatioTemporalAttention (2) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,609,786 True
│ │ └─AttentionBlock (4) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,380,096 True
│ └─ModuleList (2) -- -- -- True
│ │ └─Upsample (3) [1, 256, 70, 10, 8] [1, 256, 70, 20, 16] 263,168 True
│ │ └─ResnetBlock (0) [1, 512, 70, 20, 16] [1, 128, 70, 20, 16] 968,064 True
│ │ └─ModuleList (1) -- -- 1,132,672 True
│ │ └─AttentionBlock (4) [1, 128, 70, 20, 16] [1, 128, 70, 20, 16] 444,288 True
│ └─ModuleList (1) -- -- -- True
│ │ └─Upsample (3) [1, 128, 70, 20, 16] [1, 128, 70, 40, 32] 66,048 True
│ │ └─ResnetBlock (0) [1, 256, 70, 40, 32] [1, 64, 70, 40, 32] 258,752 True
│ │ └─ModuleList (1) -- -- 283,712 True
│ │ └─AttentionBlock (4) [1, 64, 70, 40, 32] [1, 64, 70, 40, 32] 160,704 True
│ └─ModuleList (0) -- -- -- True
│ │ └─Upsample (3) [1, 64, 70, 40, 32] [1, 64, 70, 80, 64] 16,640 True
│ │ └─ResnetBlock (0) [1, 128, 70, 80, 64] [1, 64, 70, 80, 64] 176,832 True
│ │ └─ModuleList (1) -- -- 242,752 True
│ │ └─AttentionBlock (4) [1, 64, 70, 80, 64] [1, 64, 70, 80, 64] 160,704 True
├─PseudoConv3d (conv_out) [1, 64, 70, 80, 64] [1, 4, 70, 80, 64] -- True
│ └─Conv2d (spatial_conv) [70, 64, 80, 64] [70, 4, 80, 64] 2,308 True
│ └─Conv1d (temporal_conv) [5120, 4, 70] [5120, 4, 70] 52 True
======================================================================================================================================================
Total params: 71,671,548
Trainable params: 71,671,548
Non-trainable params: 0
Total mult-adds (G): 732.56
======================================================================================================================================================
Input size (MB): 5.89
Forward/backward pass size (MB): 18136.46
Params size (MB): 286.69
Estimated Total Size (MB): 18429.04
======================================================================================================================================================