diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 0c4f491186f83..273c60a459fd4 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -43,7 +43,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
 
--
+- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))
+
 
 -
 
diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index 71a85e431bb7d..a2d73d83184b1 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -93,21 +93,10 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -
         dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
         self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
 
-        if stage == TrainerFn.FITTING:
-            if trainer.accumulate_grad_batches % trainer.log_every_n_steps != 0:
-                raise ValueError(
-                    "The `ThroughputMonitor` only logs when gradient accumulation is finished. You set"
-                    f" `Trainer(accumulate_grad_batches={trainer.accumulate_grad_batches},"
-                    f" log_every_n_steps={trainer.log_every_n_steps})` but these are not divisible and thus will not"
-                    " log anything."
-                )
-
-            if trainer.enable_validation:
-                # `fit` includes validation inside
-                throughput = Throughput(
-                    available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
-                )
-                self._throughputs[RunningStage.VALIDATING] = throughput
+        if stage == TrainerFn.FITTING and trainer.enable_validation:
+            # `fit` includes validation inside
+            throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
+            self._throughputs[RunningStage.VALIDATING] = throughput
 
         throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
         stage = trainer.state.stage
diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
index 74be939da671c..9467e45e2fa80 100644
--- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -160,7 +160,8 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
     ]
 
 
-def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
+@pytest.mark.parametrize("log_every_n_steps", [1, 3])
+def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
     logger_mock = Mock()
     logger_mock.save_dir = tmp_path
     monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
@@ -174,26 +175,8 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
         limit_train_batches=5,
         limit_val_batches=0,
         max_epochs=2,
-        log_every_n_steps=3,
+        log_every_n_steps=log_every_n_steps,
         accumulate_grad_batches=2,
-        num_sanity_val_steps=2,
-        enable_checkpointing=False,
-        enable_model_summary=False,
-        enable_progress_bar=False,
-    )
-    with pytest.raises(ValueError, match="not divisible"):
-        trainer.fit(model)
-
-    trainer = Trainer(
-        devices=1,
-        logger=logger_mock,
-        callbacks=monitor,
-        limit_train_batches=5,
-        limit_val_batches=0,
-        max_epochs=2,
-        log_every_n_steps=1,
-        accumulate_grad_batches=2,
-        num_sanity_val_steps=2,
         enable_checkpointing=False,
         enable_model_summary=False,
         enable_progress_bar=False,
@@ -211,9 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
         "train|device|flops_per_sec": 10.0,
         "train|device|mfu": 0.1,
     }
-    assert logger_mock.log_metrics.mock_calls == [
+
+    all_log_calls = [
         call(
-            metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
+            metrics={
+                # The very first batch doesn't have the *_per_sec metrics yet
+                **(expected if log_every_n_steps > 1 else {}),
+                "train|time": 2.5,
+                "train|batches": 2,
+                "train|samples": 6,
+                "train|lengths": 12,
+                "epoch": 0,
+            },
+            step=0,
         ),
         call(
             metrics={
@@ -271,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
             step=5,
         ),
     ]
+    expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
+    assert logger_mock.log_metrics.mock_calls == expected_log_calls
 
 
 @pytest.mark.parametrize("fn", ["validate", "test", "predict"])
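
Note (not part of the patch): the parametrized test builds the full list of per-step `log_metrics` calls and then selects the subset matching the cadence under test via `all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]`. A minimal self-contained sketch of why that slice models an every-n-steps logging cadence; `simulate_logging_cadence` is an illustrative name, not something from the Lightning codebase:

# Illustrative sketch only; `simulate_logging_cadence` is a hypothetical helper.

def simulate_logging_cadence(total_steps: int, log_every_n_steps: int) -> list:
    # With 0-based step indices, logging every n-th step means keeping
    # indices n - 1, 2n - 1, 3n - 1, ... -- exactly the slice used in the test.
    all_steps = list(range(total_steps))
    return all_steps[(log_every_n_steps - 1) :: log_every_n_steps]

# Mirrors the six logged steps (step=0 through step=5) asserted in the test:
assert simulate_logging_cadence(6, 1) == [0, 1, 2, 3, 4, 5]  # every step logs
assert simulate_logging_cadence(6, 3) == [2, 5]  # only steps 2 and 5 log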