20 changes: 11 additions & 9 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -17,7 +17,8 @@
import pytest
import torch

-from pytorch_lightning import callbacks, seed_everything, Trainer
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -93,7 +94,7 @@ def training_step(self, batch, batch_idx):

model = TestModel()
trainer = Trainer(
-callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)],
+callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)],
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
@@ -107,8 +108,9 @@ def training_step(self, batch, batch_idx):

@mock.patch('torch.save')
@RunIf(special=True, min_gpus=2)
@pytest.mark.parametrize("accelerator", ["ddp", pytest.param("horovod", marks=RunIf(horovod=True, skip_windows=True))])
@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)])
-def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
+def test_top_k_distributed(save_mock, tmpdir, accelerator, k, epochs, val_check_interval, expected):
Contributor: I would not include horovod in the parameterization here yet, otherwise we risk getting another flaky test. I believe we need to improve our testing integration with horovod first. #6935
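A minimal sketch (not part of this PR) of one way to follow that suggestion: keep "ddp" as the only accelerator case for now and gate the horovod case behind a flag until the horovod testing integration (#6935) is sorted out. The INCLUDE_HOROVOD_CASE flag and the trimmed test body are illustrative assumptions; RunIf and pytest.param are the same helpers already used in the diff above.

import pytest

from tests.helpers.runif import RunIf  # helper already imported by this test module

# Assumption for illustration: flip this once horovod CI is considered reliable (#6935).
INCLUDE_HOROVOD_CASE = False

ACCELERATORS = ["ddp"]
if INCLUDE_HOROVOD_CASE:
    ACCELERATORS.append(pytest.param("horovod", marks=RunIf(horovod=True, skip_windows=True)))


@pytest.mark.parametrize("accelerator", ACCELERATORS)
def test_top_k_distributed(accelerator):
    ...  # body trimmed; the real test builds the Trainer with accelerator=accelerator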

Contributor: since the main fix of this PR was merged, I suggest to close this one


class TestModel(BoringModel):

@@ -117,21 +119,21 @@ def training_step(self, batch, batch_idx):
self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True)
return super().training_step(batch, batch_idx)

-def training_epoch_end(self, outputs) -> None:
+def training_epoch_end(self, outputs):
data = str(self.global_rank)
-obj = [[data], (data, ), set(data)]
+obj = [[data], (data, ), {data}]
out = self.trainer.training_type_plugin.broadcast(obj)
-assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))]
-assert out == [['0'], ('0', ), set('0')]
+assert obj == [[str(self.global_rank)], (str(self.global_rank), ), {str(self.global_rank)}]
+assert out == [['0'], ('0', ), {'0'}]

model = TestModel()
trainer = Trainer(
-callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")],
+callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")],
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval,
accelerator="ddp",
accelerator=accelerator,
gpus=2,
limit_train_batches=64,
limit_val_batches=32,