diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py index ad593b9..c27c9bf 100644 --- a/tests/test_flop_count.py +++ b/tests/test_flop_count.py @@ -724,10 +724,8 @@ def test_batchnorm(self) -> None: ) # Test for BatchNorm2d. - batch_size = 10 - input_dim = 10 - spatial_dim_x = 5 - spatial_dim_y = 5 + batch_size = input_dim = 10 + spatial_dim_x = spatial_dim_y = 5 batch_2d = nn.BatchNorm2d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y) flop_dict, _ = flop_count(batch_2d, (x,)) @@ -739,11 +737,8 @@ def test_batchnorm(self) -> None: ) # Test for BatchNorm3d. - batch_size = 10 - input_dim = 10 - spatial_dim_x = 5 - spatial_dim_y = 5 - spatial_dim_z = 5 + batch_size = input_dim = 10 + spatial_dim_x = spatial_dim_y = spatial_dim_z = 5 batch_3d = nn.BatchNorm3d(input_dim, affine=False) x = torch.randn( batch_size, input_dim, spatial_dim_x, spatial_dim_y, spatial_dim_z