@@ -77,3 +77,73 @@ def weight_generator():
7777 assert torch .all (
7878 new_mod .nested_mod .bn .running_var == mod .nested_mod .bn .running_var )
7979 assert new_mod .nested_mod .bn .num_batches_tracked .item () == 1
80+
81+
82+ def test_module_skip_prefix ():
83+ """Ensure the auto weight loader can skip prefix."""
84+ mod = ModuleWithNestedBatchNorm ()
85+ # Run some data through the module with batchnorm
86+ mod (torch .Tensor ([[1 , 2 ], [3 , 4 ]]))
87+
88+ # Try to load the weights to a new instance
89+ def weight_generator ():
90+ # weights needed to be filtered out
91+ redundant_weights = {
92+ "prefix.bn.weight" : torch .Tensor ([1 , 2 ]),
93+ "prefix.bn.bias" : torch .Tensor ([3 , 4 ]),
94+ }
95+ yield from (mod .state_dict () | redundant_weights ).items ()
96+
97+ new_mod = ModuleWithNestedBatchNorm ()
98+
99+ assert not torch .all (
100+ new_mod .nested_mod .bn .running_mean == mod .nested_mod .bn .running_mean )
101+ assert not torch .all (
102+ new_mod .nested_mod .bn .running_var == mod .nested_mod .bn .running_var )
103+ assert new_mod .nested_mod .bn .num_batches_tracked .item () == 0
104+
105+ loader = AutoWeightsLoader (new_mod , skip_prefixes = ["prefix." ])
106+ loader .load_weights (weight_generator ())
107+
108+ # Ensure the stats are updated
109+ assert torch .all (
110+ new_mod .nested_mod .bn .running_mean == mod .nested_mod .bn .running_mean )
111+ assert torch .all (
112+ new_mod .nested_mod .bn .running_var == mod .nested_mod .bn .running_var )
113+ assert new_mod .nested_mod .bn .num_batches_tracked .item () == 1
114+
115+
116+ def test_module_skip_substr ():
117+ """Ensure the auto weight loader can skip prefix."""
118+ mod = ModuleWithNestedBatchNorm ()
119+ # Run some data through the module with batchnorm
120+ mod (torch .Tensor ([[1 , 2 ], [3 , 4 ]]))
121+
122+ # Try to load the weights to a new instance
123+ def weight_generator ():
124+ # weights needed to be filtered out
125+ redundant_weights = {
126+ "nested_mod.0.substr.weight" : torch .Tensor ([1 , 2 ]),
127+ "nested_mod.0.substr.bias" : torch .Tensor ([3 , 4 ]),
128+ "nested_mod.substr.weight" : torch .Tensor ([1 , 2 ]),
129+ "nested_mod.substr.bias" : torch .Tensor ([3 , 4 ]),
130+ }
131+ yield from (mod .state_dict () | redundant_weights ).items ()
132+
133+ new_mod = ModuleWithNestedBatchNorm ()
134+
135+ assert not torch .all (
136+ new_mod .nested_mod .bn .running_mean == mod .nested_mod .bn .running_mean )
137+ assert not torch .all (
138+ new_mod .nested_mod .bn .running_var == mod .nested_mod .bn .running_var )
139+ assert new_mod .nested_mod .bn .num_batches_tracked .item () == 0
140+
141+ loader = AutoWeightsLoader (new_mod , skip_substrs = ["substr." ])
142+ loader .load_weights (weight_generator ())
143+
144+ # Ensure the stats are updated
145+ assert torch .all (
146+ new_mod .nested_mod .bn .running_mean == mod .nested_mod .bn .running_mean )
147+ assert torch .all (
148+ new_mod .nested_mod .bn .running_var == mod .nested_mod .bn .running_var )
149+ assert new_mod .nested_mod .bn .num_batches_tracked .item () == 1
0 commit comments