2020"""
2121import time
2222
23+ import pytorch_lightning as pl
2324from pytorch_lightning .callbacks .base import Callback
2425from pytorch_lightning .utilities import _AcceleratorType , _TPU_AVAILABLE , rank_zero_deprecation , rank_zero_info
2526from pytorch_lightning .utilities .exceptions import MisconfigurationException
@@ -66,7 +67,7 @@ def __init__(self, verbose: bool = True) -> None:
6667
6768 self ._verbose = verbose
6869
69- def on_train_start (self , trainer , pl_module ) -> None :
70+ def on_train_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
7071 if not trainer .logger :
7172 raise MisconfigurationException ("Cannot use XLAStatsMonitor callback with Trainer that has no logger." )
7273
@@ -80,11 +81,13 @@ def on_train_start(self, trainer, pl_module) -> None:
8081 total_memory = trainer .strategy .reduce (memory_info ["kb_total" ]) * 0.001
8182 rank_zero_info (f"Average Total memory: { total_memory :.2f} MB" )
8283
83- def on_train_epoch_start (self , trainer , pl_module ) -> None :
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Record the wall-clock time at the start of a training epoch.

    Stores ``time.time()`` in ``self._start_time`` so that
    ``on_train_epoch_end`` can compute the elapsed epoch time
    (``time.time() - self._start_time``) for logging.

    Args:
        trainer: the ``pl.Trainer`` running the training loop (unused here).
        pl_module: the ``pl.LightningModule`` being trained (unused here).
    """
    self._start_time = time.time()
8586
86- def on_train_epoch_end (self , trainer , pl_module ) -> None :
87- logs = {}
87+ def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
88+ if not trainer .logger :
89+ raise MisconfigurationException ("Cannot use XLAStatsMonitor callback with Trainer that has no logger." )
90+
8891 memory_info = xm .get_memory_info (pl_module .device )
8992 epoch_time = time .time () - self ._start_time
9093
@@ -95,9 +98,10 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
9598 peak_memory = trainer .strategy .reduce (peak_memory ) * 0.001
9699 epoch_time = trainer .strategy .reduce (epoch_time )
97100
98- logs ["avg. free memory (MB)" ] = free_memory
99- logs ["avg. peak memory (MB)" ] = peak_memory
100- trainer .logger .log_metrics (logs , step = trainer .current_epoch )
101+ trainer .logger .log_metrics (
102+ {"avg. free memory (MB)" : float (free_memory ), "avg. peak memory (MB)" : float (peak_memory )},
103+ step = trainer .current_epoch ,
104+ )
101105
102106 if self ._verbose :
103107 rank_zero_info (f"Average Epoch time: { epoch_time :.2f} seconds" )
0 commit comments