From 38fdef5eea78d948a039fcbbe5f5465e03afa9c5 Mon Sep 17 00:00:00 2001 From: anonymousdouble <112695649+anonymousdouble@users.noreply.github.com> Date: Mon, 1 Jan 2024 21:46:49 +1100 Subject: [PATCH] Update test_flop_count.py refactor with chain constant value assignment to make code more Pythonic, concise and efficient. --- tests/test_flop_count.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py index ad593b9..c27c9bf 100644 --- a/tests/test_flop_count.py +++ b/tests/test_flop_count.py @@ -724,10 +724,8 @@ def test_batchnorm(self) -> None: ) # Test for BatchNorm2d. - batch_size = 10 - input_dim = 10 - spatial_dim_x = 5 - spatial_dim_y = 5 + batch_size = input_dim = 10 + spatial_dim_x = spatial_dim_y = 5 batch_2d = nn.BatchNorm2d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y) flop_dict, _ = flop_count(batch_2d, (x,)) @@ -739,11 +737,8 @@ def test_batchnorm(self) -> None: ) # Test for BatchNorm3d. - batch_size = 10 - input_dim = 10 - spatial_dim_x = 5 - spatial_dim_y = 5 - spatial_dim_z = 5 + batch_size = input_dim = 10 + spatial_dim_x = spatial_dim_y = spatial_dim_z = 5 batch_3d = nn.BatchNorm3d(input_dim, affine=False) x = torch.randn( batch_size, input_dim, spatial_dim_x, spatial_dim_y, spatial_dim_z