diff --git a/credit/trainers/trainerERA5_v2.py b/credit/trainers/trainerERA5_v2.py
index bb4a365..e4c9fe8 100644
--- a/credit/trainers/trainerERA5_v2.py
+++ b/credit/trainers/trainerERA5_v2.py
@@ -26,7 +26,7 @@
 from credit.data import concat_and_reshape, reshape_only
 from credit.models.checkpoint import TorchFSDPCheckpointIO
 from credit.scheduler import update_on_batch, update_on_epoch
-from credit.trainers.utils import cleanup, accum_log
+from credit.trainers.utils import cleanup, accum_log, cycle
 from credit.trainers.base_trainer import BaseTrainer
 from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
@@ -113,7 +113,7 @@ def train_one_epoch(
         )
 
         batch_group_generator = tqdm.tqdm(
-            enumerate(trainloader),
+            range(batches_per_epoch),
             total=batches_per_epoch,
             leave=True,
             disable=True if self.rank > 0 else False,
@@ -121,8 +121,9 @@
         results_dict = defaultdict(list)
 
-        for i, batch in batch_group_generator:
-
+        dl = cycle(trainloader)
+        for i in batch_group_generator:
+            batch = next(dl)  # Get the next batch from the iterator
 
             # training log
             logs = {}
             # loss
@@ -405,13 +406,15 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
         )
 
         batch_group_generator = tqdm.tqdm(
-            enumerate(valid_loader),
+            range(valid_batches_per_epoch),
             total=valid_batches_per_epoch,
             leave=True,
             disable=True if self.rank > 0 else False,
         )
 
-        for i, batch in batch_group_generator:
+        dl = cycle(valid_loader)
+        for i in batch_group_generator:
+            batch = next(dl)
             with torch.no_grad():
                 if "x_surf" in batch:
                     # combine x and x_surf
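
The change drives both loops with `range(batches_per_epoch)` / `range(valid_batches_per_epoch)` and pulls batches from an infinitely cycling iterator, so epoch length is set by the configured batch count rather than by `len(trainloader)`. The diff imports `cycle` from `credit.trainers.utils` but does not show its body; below is a minimal sketch of a helper compatible with this usage, an assumption rather than the repo's verbatim code:

```python
def cycle(dataloader):
    """Yield batches forever, restarting the dataloader when it is exhausted.

    Sketch of the helper imported from credit.trainers.utils above; the
    actual implementation in that module may differ.
    """
    while True:
        for batch in dataloader:
            yield batch
```

Re-iterating the dataloader (rather than using `itertools.cycle`, which would cache the first pass and replay identical batches) lets a shuffled sampler reshuffle on each wrap-around. That matters here because decoupling epoch length from dataset length means the loader can be exhausted mid-epoch and must restart cleanly.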