
Commit c8db9d8

Refactor SWA batch norm moment update to work with validation
1 parent 1696273

2 files changed: +67 -55 lines changed

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 36 additions & 43 deletions
@@ -120,7 +120,6 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
 
-        self.momenta = None
         self.n_averaged = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
@@ -134,6 +133,7 @@ def __init__(
         self._temp_model = None
         self._initialized = False
         self._swa_scheduler = None
+        self._batch_norm_moments = None
 
     @property
     def swa_start(self) -> int:
@@ -171,12 +171,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
 
         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
-            # virtually increase max_epochs to perform batch norm update on latest epoch.
-            trainer.fit_loop.max_epochs += 1
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        resuming_after_start = trainer.current_epoch > self.swa_start and not self._initialized
+        resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end)
         if trainer.current_epoch == self.swa_start or resuming_after_start:
             self._initialized = True
 
@@ -223,75 +220,73 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)
 
-        # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
-
-            # Transfer weights from average model to pl_module
-            self.transfer_weights(self._average_model, pl_module)
-
-            # Reset BatchNorm for update
-            self.reset_batch_norm_and_save_state(pl_module)
-
-            # There is no need to perform either backward or optimizer.step as we are
-            # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
-            trainer.num_training_batches += 1
-            trainer.fit_loop._skip_backward = True
-            self._accumulate_grad_batches = trainer.accumulate_grad_batches
-
-            trainer.accumulate_grad_batches = trainer.num_training_batches
-
-    def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
-        trainer.fit_loop._skip_backward = False
-
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
-            # BatchNorm epoch update. Reset state
-            trainer.accumulate_grad_batches = self._accumulate_grad_batches
-            trainer.num_training_batches -= 1
-            trainer.fit_loop.max_epochs -= 1
-            self.reset_momenta()
-        elif trainer.current_epoch == self.swa_end:
+        if trainer.current_epoch == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._update_batch_norm_moments(trainer, pl_module, store_moments=False)
 
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
             # Take a temporary copy of the model parameters
             self.transfer_weights(pl_module, self._temp_model)
             # Update the model with the averaged parameters
             self.transfer_weights(self._average_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._update_batch_norm_moments(trainer, pl_module)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
             # Copy original model parameters back
             self.transfer_weights(self._temp_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._restore_batch_norm_moments()
 
     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
         for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
             dst_param.detach().copy_(src_param.to(dst_param.device))
 
-    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
-        self.momenta = {}
+    def _update_batch_norm_moments(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True
+    ):
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166."""
+        prev_momenta = {}
+        self._batch_norm_moments = {}
+
+        was_training = pl_module.training
+        pl_module.train()
+
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            prev_momenta[module] = module.momentum
+            if store_moments:
+                self._batch_norm_moments[module] = (module.running_mean, module.running_var)
             module.running_mean = torch.zeros_like(
                 module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
             )
             module.running_var = torch.ones_like(
                 module.running_var, device=pl_module.device, dtype=module.running_var.dtype
             )
-            self.momenta[module] = module.momentum
             module.momentum = None
             module.num_batches_tracked *= 0
 
-    def reset_momenta(self):
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
-        for bn_module in self.momenta:
-            bn_module.momentum = self.momenta[bn_module]
+        # Recompute mean and variance for all batch norm layers by doing a full pass over the training data
+        for batch, _ in trainer.data_connector.train_data_fetcher:
+            batch = batch.to(pl_module.device)
+            pl_module(batch)
+
+        # Reset model state
+        for bn_module, momenta in prev_momenta.items():
+            bn_module.momentum = momenta
+        pl_module.train(was_training)
+
+    def _restore_batch_norm_moments(self):
+        for bn_module, (mean, variance) in self._batch_norm_moments.items():
+            bn_module.running_mean = mean
+            bn_module.running_var = variance
 
     @staticmethod
     def update_parameters(
@@ -317,7 +312,6 @@ def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> dict:
         checkpoint_data = {
-            "momenta": self.momenta,
             "n_averaged": self.n_averaged,
             "swa_lrs": self._swa_lrs,
             "annealing_epochs": self._annealing_epochs,
@@ -330,7 +324,6 @@ def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
         if callback_state:
-            self.momenta = callback_state["momenta"]
            self.n_averaged = callback_state["n_averaged"]
            self._swa_lrs = callback_state["swa_lrs"]
            self._annealing_strategy = callback_state["annealing_strategy"]
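
For reference, the sketch below distills the recomputation that the new `_update_batch_norm_moments` helper performs into a standalone, plain-PyTorch function: reset the running statistics, run one forward pass over the training data in train mode with `momentum = None` (so BatchNorm accumulates a cumulative average), then restore the momenta. This is only an illustration of the technique under stated assumptions; the function name and the `model`/`train_loader`/`device` arguments are placeholders, not part of this commit, and `torch.optim.swa_utils.update_bn` provides an equivalent upstream utility.

import torch
from torch import nn


def recompute_bn_moments(model: nn.Module, train_loader, device):
    """Sketch: recompute BatchNorm running statistics with a single pass over the data."""
    bn_modules = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_modules:
        return {}

    was_training = model.training
    model.train()

    prev_momenta, saved_moments = {}, {}
    for bn in bn_modules:
        prev_momenta[bn] = bn.momentum
        # Keep the previous statistics so they can be restored after validation.
        saved_moments[bn] = (bn.running_mean, bn.running_var)
        bn.running_mean = torch.zeros_like(bn.running_mean)
        bn.running_var = torch.ones_like(bn.running_var)
        bn.momentum = None  # None => cumulative moving average over the pass
        bn.num_batches_tracked.zero_()

    # Forward passes only; no backward or optimizer step is needed.
    with torch.no_grad():
        for batch, _ in train_loader:
            model(batch.to(device))

    for bn in bn_modules:
        bn.momentum = prev_momenta[bn]
    model.train(was_training)
    return saved_moments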

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 31 additions & 12 deletions
@@ -52,13 +52,21 @@ def training_step(self, batch, batch_idx):
         loss = self.loss(batch, output)
         return {"loss": loss}
 
+    def validation_step(self, batch, batch_idx):
+        output = self.forward(batch)
+        loss = self.loss(batch, output)
+        return {"x": loss}
+
     def train_dataloader(self):
 
         dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
         dset = dset_cls(32, 64)
 
         return DataLoader(dset, batch_size=2)
 
+    def val_dataloader(self):
+        return self.train_dataloader()
+
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         return {
@@ -86,20 +94,24 @@ def __init__(self, *args, **kwargs):
         self.resuming_from_epoch = 0
         super().__init__(*args, **kwargs)
 
+    validation_calls: int = 0
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
 
     def update_parameters(self, *args, **kwargs):
         self.update_parameters_calls += 1
         return StochasticWeightAveraging.update_parameters(*args, **kwargs)
 
+    def on_validation_start(self, *args, **kwargs):
+        self.validation_calls += 1
+        return super().on_validation_start(*args, **kwargs)
+
     def transfer_weights(self, *args, **kwargs):
         self.transfer_weights_calls += 1
         return StochasticWeightAveraging.transfer_weights(*args, **kwargs)
 
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
-        assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
         if self.swa_start <= trainer.current_epoch:
             assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
             assert trainer.lr_schedulers[0]["interval"] == "epoch"
@@ -116,11 +128,6 @@ def on_train_epoch_end(self, trainer, *args):
     def on_train_end(self, trainer, pl_module):
         super().on_train_end(trainer, pl_module)
 
-        # make sure these are correctly set again
-        assert not trainer.fit_loop._skip_backward
-        assert trainer.accumulate_grad_batches == 2
-        assert trainer.num_training_batches == 5
-
         if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin):
             # check backward call count. the batchnorm update epoch should not backward
             assert trainer.accelerator.backward.call_count == (
@@ -133,16 +140,27 @@ def on_train_end(self, trainer, pl_module):
         else:
             expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1)
         assert self.update_parameters_calls == expected_update_calls
-        assert self.transfer_weights_calls == 1
+        if self._swa_validation:
+            # 3 weight transfers are needed per SWA validation step
+            assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1
+        else:
+            assert self.transfer_weights_calls == 1
 
 
 def train_with_swa(
-    tmpdir, batchnorm=True, strategy=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False
+    tmpdir,
+    batchnorm=True,
+    strategy=None,
+    gpus=None,
+    num_processes=1,
+    interval="epoch",
+    iterable_dataset=False,
+    validation=False,
 ):
     model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
     swa_start = 2
     max_epochs = 5
-    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
+    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=validation)
     assert swa_callback.update_parameters_calls == 0
     assert swa_callback.transfer_weights_calls == 0
 
@@ -151,7 +169,7 @@ def train_with_swa(
         enable_progress_bar=False,
         max_epochs=max_epochs,
         limit_train_batches=5,
-        limit_val_batches=0,
+        limit_val_batches=1.0 if validation else 0.0,
         callbacks=[swa_callback],
         accumulate_grad_batches=2,
         strategy=strategy,
@@ -188,8 +206,9 @@ def test_swa_callback_1_gpu(tmpdir):
 
 @pytest.mark.parametrize("batchnorm", (True, False))
 @pytest.mark.parametrize("iterable_dataset", (True, False))
-def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
-    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)
+@pytest.mark.parametrize("validation", (True, False))
+def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool, validation: bool):
+    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset, validation=validation)
 
 
 @pytest.mark.parametrize("interval", ("epoch", "step"))
