Commit

Fixed tqdm bug and tested this trainer against grad-accum for single step
jsschreck committed Dec 26, 2024
1 parent f59ff3c commit d50849e
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions credit/trainers/trainerERA5_v2.py
@@ -26,7 +26,7 @@
 from credit.data import concat_and_reshape, reshape_only
 from credit.models.checkpoint import TorchFSDPCheckpointIO
 from credit.scheduler import update_on_batch, update_on_epoch
-from credit.trainers.utils import cleanup, accum_log
+from credit.trainers.utils import cleanup, accum_log, cycle
 from credit.trainers.base_trainer import BaseTrainer
 from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer

@@ -113,16 +113,17 @@ def train_one_epoch(
         )

         batch_group_generator = tqdm.tqdm(
-            enumerate(trainloader),
+            range(batches_per_epoch),
             total=batches_per_epoch,
             leave=True,
             disable=True if self.rank > 0 else False,
         )

         results_dict = defaultdict(list)

-        for i, batch in batch_group_generator:
-
+        dl = cycle(trainloader)
+        for i in batch_group_generator:
+            batch = next(dl)  # Get the next batch from the iterator
             # training log
             logs = {}
             # loss
@@ -405,13 +406,15 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
         )

         batch_group_generator = tqdm.tqdm(
-            enumerate(valid_loader),
+            range(valid_batches_per_epoch),
             total=valid_batches_per_epoch,
             leave=True,
             disable=True if self.rank > 0 else False,
         )

-        for i, batch in batch_group_generator:
+        dl = cycle(valid_loader)
+        for i in batch_group_generator:
+            batch = next(dl)
             with torch.no_grad():
                 if "x_surf" in batch:
                     # combine x and x_surf
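The change points tqdm at range(batches_per_epoch) and pulls batches from an endlessly cycling iterator over the DataLoader, so the loop always runs exactly batches_per_epoch steps regardless of the loader's own length. The body of the cycle helper imported from credit.trainers.utils is not part of this diff; a minimal sketch of the pattern, assuming a simple generator that restarts the loader when it is exhausted, is shown below (the toy DataLoader and batches_per_epoch value are placeholders, not code from this repository):

import torch
import tqdm
from torch.utils.data import DataLoader, TensorDataset


def cycle(dl):
    # Assumed implementation of credit.trainers.utils.cycle:
    # yield batches forever, restarting the DataLoader each time it runs out.
    while True:
        for batch in dl:
            yield batch


# Toy loader with 5 batches stands in for the ERA5 trainloader.
loader = DataLoader(TensorDataset(torch.randn(20, 3)), batch_size=4)
batches_per_epoch = 8  # may exceed len(loader); cycle() keeps supplying batches

dl = cycle(loader)
for i in tqdm.tqdm(range(batches_per_epoch), total=batches_per_epoch):
    batch = next(dl)  # never raises StopIteration, so the bar completes

With the old enumerate(trainloader) form, the number of loop iterations was set by the loader while total was set by batches_per_epoch, so the progress bar could finish early or report the wrong total whenever the two disagreed; driving the loop with range(batches_per_epoch) removes that mismatch.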
