@@ -262,22 +262,6 @@ def is_enabled(self) -> bool:
262262 def is_disabled (self ) -> bool :
263263 return not self .is_enabled
264264
265- @property
266- def sanity_check_description (self ) -> str :
267- return "Validation Sanity Check"
268-
269- @property
270- def validation_description (self ) -> str :
271- return "Validation"
272-
273- @property
274- def test_description (self ) -> str :
275- return "Testing"
276-
277- @property
278- def predict_description (self ) -> str :
279- return "Predicting"
280-
281265 def _update_for_light_colab_theme (self ) -> None :
282266 if _detect_light_colab_theme ():
283267 attributes = ["description" , "batch_progress" , "metrics" ]
@@ -354,13 +338,28 @@ def on_train_epoch_start(self, trainer, pl_module):
354338 )
355339 self .refresh ()
356340
357- def on_validation_epoch_start (self , trainer , pl_module ):
341+ def on_validation_batch_start (
342+ self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
343+ ) -> None :
344+ if not self .has_dataloader_changed (dataloader_idx ):
345+ return
346+
358347 if trainer .sanity_checking :
359- self .val_sanity_progress_bar_id = self ._add_task (self .total_val_batches , self .sanity_check_description )
348+ if self .val_sanity_progress_bar_id is not None :
349+ self .progress .update (self .val_sanity_progress_bar_id , advance = 0 , visible = False )
350+
351+ self .val_sanity_progress_bar_id = self ._add_task (
352+ self .total_val_batches_current_dataloader , self .sanity_check_description , visible = False
353+ )
360354 else :
355+ if self .val_progress_bar_id is not None :
356+ self .progress .update (self .val_progress_bar_id , advance = 0 , visible = False )
357+
358+ # TODO: remove old tasks when new onces are created
361359 self .val_progress_bar_id = self ._add_task (
362- self .total_val_batches , self .validation_description , visible = False
360+ self .total_val_batches_current_dataloader , self .validation_description , visible = False
363361 )
362+
364363 self .refresh ()
365364
366365 def _add_task (self , total_batches : int , description : str , visible : bool = True ) -> Optional [int ]:
@@ -387,13 +386,36 @@ def on_validation_epoch_end(self, trainer, pl_module):
387386 def on_validation_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
388387 if trainer .state .fn == "fit" :
389388 self ._update_metrics (trainer , pl_module )
389+ self .reset_dataloader_idx_tracker ()
390+
391+ def on_test_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
392+ self .reset_dataloader_idx_tracker ()
390393
391- def on_test_epoch_start (self , trainer , pl_module ):
392- self .test_progress_bar_id = self ._add_task (self .total_test_batches , self .test_description )
394+ def on_predict_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
395+ self .reset_dataloader_idx_tracker ()
396+
397+ def on_test_batch_start (
398+ self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
399+ ) -> None :
400+ if not self .has_dataloader_changed (dataloader_idx ):
401+ return
402+
403+ if self .test_progress_bar_id is not None :
404+ self .progress .update (self .test_progress_bar_id , advance = 0 , visible = False )
405+ self .test_progress_bar_id = self ._add_task (self .total_test_batches_current_dataloader , self .test_description )
393406 self .refresh ()
394407
395- def on_predict_epoch_start (self , trainer , pl_module ):
396- self .predict_progress_bar_id = self ._add_task (self .total_predict_batches , self .predict_description )
408+ def on_predict_batch_start (
409+ self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
410+ ) -> None :
411+ if not self .has_dataloader_changed (dataloader_idx ):
412+ return
413+
414+ if self .predict_progress_bar_id is not None :
415+ self .progress .update (self .predict_progress_bar_id , advance = 0 , visible = False )
416+ self .predict_progress_bar_id = self ._add_task (
417+ self .total_predict_batches_current_dataloader , self .predict_description
418+ )
397419 self .refresh ()
398420
399421 def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
@@ -406,20 +428,23 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
406428
407429 def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
408430 if trainer .sanity_checking :
409- self ._update (self .val_sanity_progress_bar_id , self .val_batch_idx , self .total_val_batches )
431+ self ._update (self .val_sanity_progress_bar_id , self .val_batch_idx , self .total_val_batches_current_dataloader )
410432 elif self .val_progress_bar_id is not None :
411433 # check to see if we should update the main training progress bar
412434 if self .main_progress_bar_id is not None :
413- self ._update (self .main_progress_bar_id , self .val_batch_idx , self .total_val_batches )
414- self ._update (self .val_progress_bar_id , self .val_batch_idx , self .total_val_batches )
435+ # TODO: Use total val_processed here just like TQDM in a follow-up
436+ self ._update (self .main_progress_bar_id , self .val_batch_idx , self .total_val_batches_current_dataloader )
437+ self ._update (self .val_progress_bar_id , self .val_batch_idx , self .total_val_batches_current_dataloader )
415438 self .refresh ()
416439
417440 def on_test_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
418- self ._update (self .test_progress_bar_id , self .test_batch_idx , self .total_test_batches )
441+ self ._update (self .test_progress_bar_id , self .test_batch_idx , self .total_test_batches_current_dataloader )
419442 self .refresh ()
420443
421444 def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
422- self ._update (self .predict_progress_bar_id , self .predict_batch_idx , self .total_predict_batches )
445+ self ._update (
446+ self .predict_progress_bar_id , self .predict_batch_idx , self .total_predict_batches_current_dataloader
447+ )
423448 self .refresh ()
424449
425450 def _get_train_description (self , current_epoch : int ) -> str :
0 commit comments