
Commit 3a2df4f

awaelchli and carmocca authored
Fix typing in pl.callbacks.xla_stats_monitor (#11219)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
1 parent 9b873dc · commit 3a2df4f

File tree

pyproject.toml
pytorch_lightning/callbacks/xla_stats_monitor.py

2 files changed: 11 additions & 8 deletions


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ module = [
     "pytorch_lightning.callbacks.progress.tqdm_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
-    "pytorch_lightning.callbacks.xla_stats_monitor",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.lightning",

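Dropping "pytorch_lightning.callbacks.xla_stats_monitor" from this override list (the block of modules exempted from mypy checking) means the annotations added below are now actually verified. As a minimal sketch, not the project's actual CI command, the module could be spot-checked through mypy's Python API:

# A spot-check of the newly un-ignored module; assumes mypy is installed
# and the Lightning repository root is the current working directory.
from mypy import api

stdout, stderr, exit_status = api.run(["pytorch_lightning/callbacks/xla_stats_monitor.py"])
print(stdout)  # mypy's report; exit_status == 0 means no type errors were found
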
pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 11 additions & 7 deletions
@@ -20,6 +20,7 @@
 """
 import time
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -66,7 +67,7 @@ def __init__(self, verbose: bool = True) -> None:
 
         self._verbose = verbose
 
-    def on_train_start(self, trainer, pl_module) -> None:
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not trainer.logger:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
@@ -80,11 +81,13 @@ def on_train_start(self, trainer, pl_module) -> None:
         total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
         rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
 
-    def on_train_epoch_start(self, trainer, pl_module) -> None:
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time = time.time()
 
-    def on_train_epoch_end(self, trainer, pl_module) -> None:
-        logs = {}
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if not trainer.logger:
+            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
+
         memory_info = xm.get_memory_info(pl_module.device)
         epoch_time = time.time() - self._start_time
 
@@ -95,9 +98,10 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)
 
-        logs["avg. free memory (MB)"] = free_memory
-        logs["avg. peak memory (MB)"] = peak_memory
-        trainer.logger.log_metrics(logs, step=trainer.current_epoch)
+        trainer.logger.log_metrics(
+            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+            step=trainer.current_epoch,
+        )
 
         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")

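The new annotations are string forward references ("pl.Trainer", "pl.LightningModule") paired with import pytorch_lightning as pl: the strings are only evaluated by the type checker, never at import time, which avoids a circular import while the pytorch_lightning package is still initializing. For context, a minimal usage sketch of the callback, assuming a TPU machine and the Trainer API of this release (tpu_cores was the TPU selector at the time; note the callback was already deprecated here in favor of DeviceStatsMonitor, hence the rank_zero_deprecation import):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import XLAStatsMonitor

# Logs average free/peak TPU memory and epoch time once per training epoch.
# A logger is required: on_train_start and on_train_epoch_end now both raise
# MisconfigurationException when trainer.logger is not set.
trainer = pl.Trainer(callbacks=[XLAStatsMonitor(verbose=True)], tpu_cores=8)
# trainer.fit(model) would then be called with any LightningModule.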