Skip to content

Commit

Permalink
Added fixed random seed to not depend of randomness of initialized we…
Browse files Browse the repository at this point in the history
…ights (#1839)
  • Loading branch information
BloodAxe authored Feb 15, 2024
1 parent a6f998c commit b164da5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/unit_tests/repvgg_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,25 @@ def test_deployment_architecture(self):
"""
image_size = 224
in_channels = 3

for arch_name in ARCHITECTURES:
# skip custom constructors to keep all_arch_params as general as a possible
if "repvgg" not in arch_name or "custom" in arch_name:
continue

with self.subTest(arch_name=arch_name):
# Set the seed to 0 to ensure that the model is initialized with the same weights
torch.manual_seed(0)
model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode
self.assertTrue(model.build_residual_branches)

training_mode_sd = model.state_dict()
for module in training_mode_sd:
self.assertFalse("reparam" in module) # deployment block included in training mode
test_input = torch.ones((1, in_channels, image_size, image_size))

# Initializing input with 0.1 instead of 1.0 to move mean of input closer to 0
test_input = torch.ones((1, in_channels, image_size, image_size)) * 0.1
model.eval()
training_mode_output = model(test_input)

Expand All @@ -83,6 +88,8 @@ def test_backbone_mode(self):
"""
image_size = 224
in_channels = 3
# Set the seed to 0 to ensure that the model is initialized with the same weights
torch.manual_seed(0)
test_input = torch.rand((1, in_channels, image_size, image_size))
backbone_model = RepVggA1(self.backbone_arch_params)
model = BackboneBasedModel(backbone_model, backbone_output_channel=1280, num_classes=self.backbone_arch_params.num_classes)
Expand Down

0 comments on commit b164da5

Please sign in to comment.