Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix log_every_n_steps check in ThroughputMonitor #19470

Merged
merged 9 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))


-
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))


-

Expand Down
21 changes: 6 additions & 15 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,12 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)

if stage == TrainerFn.FITTING:
if trainer.accumulate_grad_batches % trainer.log_every_n_steps != 0:
raise ValueError(
"The `ThroughputMonitor` only logs when gradient accumulation is finished. You set"
f" `Trainer(accumulate_grad_batches={trainer.accumulate_grad_batches},"
f" log_every_n_steps={trainer.log_every_n_steps})` but these are not divisible and thus will not"
" log anything."
)

if trainer.enable_validation:
# `fit` includes validation inside
throughput = Throughput(
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
)
self._throughputs[RunningStage.VALIDATING] = throughput
if stage == TrainerFn.FITTING and trainer.enable_validation:
# `fit` includes validation inside
throughput = Throughput(
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
)
self._throughputs[RunningStage.VALIDATING] = throughput

throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
stage = trainer.state.stage
Expand Down
39 changes: 17 additions & 22 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
]


def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
@pytest.mark.parametrize("log_every_n_steps", [1, 3])
def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
logger_mock = Mock()
logger_mock.save_dir = tmp_path
monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
Expand All @@ -174,26 +175,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
limit_train_batches=5,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=3,
log_every_n_steps=log_every_n_steps,
accumulate_grad_batches=2,
num_sanity_val_steps=2,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
with pytest.raises(ValueError, match="not divisible"):
trainer.fit(model)

trainer = Trainer(
devices=1,
logger=logger_mock,
callbacks=monitor,
limit_train_batches=5,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=1,
accumulate_grad_batches=2,
num_sanity_val_steps=2,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
Expand All @@ -211,9 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
"train|device|flops_per_sec": 10.0,
"train|device|mfu": 0.1,
}
assert logger_mock.log_metrics.mock_calls == [

all_log_calls = [
call(
metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
metrics={
# The very first batch doesn't have the *_per_sec metrics yet
**(expected if log_every_n_steps > 1 else {}),
"train|time": 2.5,
"train|batches": 2,
"train|samples": 6,
"train|lengths": 12,
"epoch": 0,
},
step=0,
),
call(
metrics={
Expand Down Expand Up @@ -271,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
step=5,
),
]
expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
assert logger_mock.log_metrics.mock_calls == expected_log_calls


@pytest.mark.parametrize("fn", ["validate", "test", "predict"])
Expand Down
Loading