1616"""A space-time Transformer with Cuboid Attention"""
1717
1818
19- class InitialEncoder (paddle . nn .Layer ):
19+ class InitialEncoder (nn .Layer ):
2020 def __init__ (
2121 self ,
2222 dim ,
@@ -38,39 +38,35 @@ def __init__(
         for i in range(num_conv_layers):
             if i == 0:
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=dim,
                         out_channels=out_dim,
                     )
                 )
-                conv_block.append(
-                    paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
-                )
+                conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
                 conv_block.append(
                     act_mod.get_activation(activation)
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
             else:
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=out_dim,
                         out_channels=out_dim,
                     )
                 )
-                conv_block.append(
-                    paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
-                )
+                conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
                 conv_block.append(
                     act_mod.get_activation(activation)
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-        self.conv_block = paddle.nn.Sequential(*conv_block)
+        self.conv_block = nn.Sequential(*conv_block)
         if isinstance(downsample_scale, int):
             patch_merge_downsample = (1, downsample_scale, downsample_scale)
         elif len(downsample_scale) == 2:
@@ -121,7 +117,7 @@ def forward(self, x):
         return x


-class FinalDecoder(paddle.nn.Layer):
+class FinalDecoder(nn.Layer):
     def __init__(
         self,
         target_thw: Tuple[int, ...],
@@ -142,20 +138,20 @@ def __init__(
         conv_block = []
         for i in range(num_conv_layers):
             conv_block.append(
-                paddle.nn.Conv2D(
+                nn.Conv2D(
                     kernel_size=(3, 3),
                     padding=(1, 1),
                     in_channels=dim,
                     out_channels=dim,
                 )
             )
-            conv_block.append(paddle.nn.GroupNorm(num_groups=16, num_channels=dim))
+            conv_block.append(nn.GroupNorm(num_groups=16, num_channels=dim))
             conv_block.append(
                 act_mod.get_activation(activation)
                 if activation != "leaky_relu"
                 else nn.LeakyReLU(NEGATIVE_SLOPE)
             )
-        self.conv_block = paddle.nn.Sequential(*conv_block)
+        self.conv_block = nn.Sequential(*conv_block)
         self.upsample = cuboid_decoder.Upsample3DLayer(
             dim=dim,
             out_dim=dim,
@@ -196,7 +192,7 @@ def forward(self, x):
         return x


-class InitialStackPatchMergingEncoder(paddle.nn.Layer):
+class InitialStackPatchMergingEncoder(nn.Layer):
     def __init__(
         self,
         num_merge: int,
@@ -220,8 +216,8 @@ def __init__(
         self.downsample_scale_list = downsample_scale_list[:num_merge]
         self.num_conv_per_merge_list = num_conv_per_merge_list
         self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
-        self.conv_block_list = paddle.nn.LayerList()
-        self.patch_merge_list = paddle.nn.LayerList()
+        self.conv_block_list = nn.LayerList()
+        self.patch_merge_list = nn.LayerList()
         for i in range(num_merge):
             if i == 0:
                 in_dim = in_dim
@@ -236,15 +232,15 @@ def __init__(
                 else:
                     conv_in_dim = out_dim
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=conv_in_dim,
                         out_channels=out_dim,
                     )
                 )
                 conv_block.append(
-                    paddle.nn.GroupNorm(
+                    nn.GroupNorm(
                         num_groups=self.num_group_list[i], num_channels=out_dim
                     )
                 )
@@ -253,7 +249,7 @@ def __init__(
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-            conv_block = paddle.nn.Sequential(*conv_block)
+            conv_block = nn.Sequential(*conv_block)
             self.conv_block_list.append(conv_block)
             patch_merge = cuboid_encoder.PatchMerging3D(
                 dim=out_dim,
@@ -303,7 +299,7 @@ def forward(self, x):
         return x


-class FinalStackUpsamplingDecoder(paddle.nn.Layer):
+class FinalStackUpsamplingDecoder(nn.Layer):
     def __init__(
         self,
         target_shape_list: Tuple[Tuple[int, ...]],
@@ -326,8 +322,8 @@ def __init__(
         self.in_dim = in_dim
         self.num_conv_per_up_list = num_conv_per_up_list
         self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
-        self.conv_block_list = paddle.nn.LayerList()
-        self.upsample_list = paddle.nn.LayerList()
+        self.conv_block_list = nn.LayerList()
+        self.upsample_list = nn.LayerList()
         for i in range(self.num_upsample):
             if i == 0:
                 in_dim = in_dim
@@ -349,15 +345,15 @@ def __init__(
                 else:
                     conv_in_dim = out_dim
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=conv_in_dim,
                         out_channels=out_dim,
                     )
                 )
                 conv_block.append(
-                    paddle.nn.GroupNorm(
+                    nn.GroupNorm(
                         num_groups=self.num_group_list[i], num_channels=out_dim
                     )
                 )
@@ -366,7 +362,7 @@ def __init__(
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-            conv_block = paddle.nn.Sequential(*conv_block)
+            conv_block = nn.Sequential(*conv_block)
             self.conv_block_list.append(conv_block)
         self.reset_parameters()

@@ -686,7 +682,7 @@ def __init__(
             embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in
         )
         mem_shapes = self.encoder.get_mem_shapes()
-        self.z_proj = paddle.nn.Linear(
+        self.z_proj = nn.Linear(
             in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1]
         )
         self.dec_pos_embed = cuboid_decoder.PosEmbed(
@@ -799,7 +795,7 @@ def get_initial_encoder_final_decoder(
             new_input_shape = self.initial_encoder.patch_merge.get_out_shape(
                 self.input_shape
             )
-            self.dec_final_proj = paddle.nn.Linear(
+            self.dec_final_proj = nn.Linear(
                 in_features=self.base_units, out_features=C_out
             )
         elif self.initial_downsample_type == "stack_conv":
@@ -839,7 +835,7 @@ def get_initial_encoder_final_decoder(
                 linear_init_mode=self.down_up_linear_init_mode,
                 norm_init_mode=self.norm_init_mode,
             )
-            self.dec_final_proj = paddle.nn.Linear(
+            self.dec_final_proj = nn.Linear(
                 in_features=dec_target_shape_list[-1][-1], out_features=C_out
             )
             new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[
@@ -892,7 +888,7 @@ def get_initial_z(self, final_mem, T_out):
                 shape=[B, -1, -1, -1, -1]
             )
         elif self.z_init_method == "nearest_interp":
-            initial_z = paddle.nn.functional.interpolate(
+            initial_z = nn.functional.interpolate(
                 x=final_mem.transpose(perm=[0, 4, 1, 2, 3]),
                 size=(T_out, final_mem.shape[2], final_mem.shape[3]),
             ).transpose(perm=[0, 2, 3, 4, 1])
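
The hunks above consistently swap fully qualified `paddle.nn.*` references for the shorter `nn.*` alias. Below is a minimal sketch of the pattern, assuming the module imports the namespace as `from paddle import nn` (the import block is not part of these hunks, so the exact alias line is an assumption). The `conv_block` here only illustrates the Conv2D + GroupNorm + LeakyReLU stage used by InitialEncoder/FinalDecoder, and the `NEGATIVE_SLOPE` value is a placeholder rather than the constant defined elsewhere in the file.

# Sketch only: shows that nn.Conv2D / nn.GroupNorm / nn.Sequential resolve
# through the short alias exactly as paddle.nn.* did before the refactor.
import paddle
from paddle import nn  # assumed import; not visible in the diff hunks above

NEGATIVE_SLOPE = 0.1  # placeholder; the real constant is defined elsewhere in the module

# One Conv2D -> GroupNorm -> LeakyReLU stage, mirroring the encoder/decoder conv blocks.
conv_block = nn.Sequential(
    nn.Conv2D(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
    nn.GroupNorm(num_groups=16, num_channels=16),
    nn.LeakyReLU(NEGATIVE_SLOPE),
)

x = paddle.randn([1, 8, 32, 32])  # NCHW input
y = conv_block(x)
print(y.shape)  # [1, 16, 32, 32]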