Ensemble capability for single-step training #137

Open · wants to merge 2 commits into base: main

Conversation

@dkimpara (Collaborator) commented Dec 20, 2024

  • supports batch_size > 1
  • currently only KCRPS (unbiased CRPS) is available (a short sketch of the estimator follows this list)
  • train and validation use the same loss in this mode
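
For reference, the KCRPS referred to above is the fair (unbiased) kernel CRPS estimator, roughly E|X - y| - 0.5 E|X - X'|, with the pairwise term normalized by m * (m - 1) rather than m^2. A minimal, self-contained sketch for a single observation (the function name and shapes are illustrative only, not the actual KCRPSLoss API):

import torch

def fair_kcrps(ens: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # ens: (m, ...) ensemble members; obs: (...) a single observation
    m = ens.shape[0]
    # skill term: mean absolute error of the members against the observation
    skill = (ens - obs).abs().mean(dim=0)
    # spread term: mean pairwise distance between members, with the
    # unbiased 1 / (m * (m - 1)) normalization
    spread = (ens.unsqueeze(0) - ens.unsqueeze(1)).abs().sum(dim=(0, 1)) / (m * (m - 1))
    # average the pointwise CRPS over all remaining dimensions
    return (skill - 0.5 * spread).mean()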

new config field:
trainer.ensemble_size (int), defaults to 1 (also set by the parser)

high-level code design (a rough end-to-end sketch follows this list):

  • era5trainer_v2: torch.repeat_interleave copies samples to transform batches from (b, ...) to (b * ensemble_size, ...); this expanded tensor is then passed into the model
  • loss.py, KCRPSLoss: torch.vmap vectorizes and lifts the single-observation modulus loss to handle ensembles with multiple target observations (batch_size > 1)
  • metrics.py: computes the ensemble mean, which the metrics are then evaluated on
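
The sketch below strings these three steps together (conf keys and tensor shapes follow the snippets in this PR; ensemble_step, model, and crps_loss are stand-ins, not the actual class or function names):

import torch

def ensemble_step(conf, model, crps_loss, x, y):
    # x: (b, c, t, lat, lon) inputs; y: (b, c, t, lat, lon) targets
    ensemble_size = conf["trainer"]["ensemble_size"]

    # trainer: expand the batch from (b, ...) to (b * ensemble_size, ...),
    # duplicating each sample ensemble_size times back to back
    x = torch.repeat_interleave(x, ensemble_size, 0)
    pred = model(x)  # (b * ensemble_size, c, t, lat, lon)

    # loss: KCRPSLoss receives b * ensemble_size predictions and b targets;
    # internally it regroups them per sample and uses torch.vmap to apply the
    # single-observation kernel CRPS across the batch
    loss = crps_loss(pred, y)

    # metrics: collapse the ensemble back to its mean so the existing
    # deterministic metrics can be reused unchanged
    pred_mean = pred.view(y.shape[0], ensemble_size, *y.shape[1:]).mean(dim=1)
    return loss, pred_mean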

Future features (left out of this PR to avoid cluttering the config):

  • different train/validation losses (including ensemble/not-ensemble pairings)
  • different ensemble sizes for train/validation

Testing (on casper)

applications/train.py -c config/test_cesm_ensemble.yml -l 1

@dkimpara marked this pull request as ready for review December 20, 2024 18:36
@jsschreck (Collaborator) commented:

@dkimpara regarding batch_size > 1: does this mean it will support the new datasets I added, or are you doing it yourself here?

@dkimpara (Collaborator, Author) commented Dec 25, 2024 via email

@jsschreck (Collaborator) commented:

@dkimpara trainerERA5_v2.py will be deprecated in my PR; we will only need trainerERA5_multistep_grad_accum.py going forward, so you should add the ensemble change to that script as well. Keep what you have b/c I am not going to remove _v2 just yet since other people are still using it.

@yingkaisha (Collaborator) left a comment:

Finished review. My request is that deterministic training should skip the CRPS code blocks.

# if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)

@yingkaisha (Collaborator) commented Dec 29, 2024:

This place can have an if condition, so the deterministic training routine will not touch this part.

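One possible shape for that guard (illustrative only; the exact condition or flag is left to the author):

# only expand the batch when ensemble training is requested, so the
# deterministic training routine never touches this branch
if conf["trainer"]["ensemble_size"] > 1:
    x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)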

A collaborator commented:

if condition

# calculate ensemble mean, if ensemble_size=1, does nothing
pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:]) #b, ensemble, c, t, lat, lon
pred = pred.mean(dim=1)

@yingkaisha (Collaborator) commented Dec 29, 2024:

if condition
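
A sketch of what that guard could look like in metrics.py (assuming self.ensemble_size defaults to 1 for deterministic runs):

# skip the reshape entirely for deterministic runs
if self.ensemble_size > 1:
    # regroup predictions as (b, ensemble, c, t, lat, lon) and average the members
    pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:])
    pred = pred.mean(dim=1)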

A collaborator commented:

Yeah -- will also fail here.

@@ -710,6 +710,9 @@ def credit_main_parser(
"train_batch_size" in conf["trainer"]
), "Training set batch size ('train_batch_size') is missing from onf['trainer']"

if "ensemble_size" not in conf["trainer"]:
conf["trainer"]["ensemble_size"] = 1

@yingkaisha (Collaborator) commented Dec 29, 2024:

Can you have an option to skip the ensemble routine? For those who do not use ensemble training, going through the ensemble code can cause a slow-down.

@@ -28,11 +28,17 @@ def __init__(self, conf, predict_mode=False):
# DO NOT apply these weights during metrics computations, only on the loss during
self.w_var = None

self.ensemble_size = conf["trainer"]["ensemble_size"]

A collaborator commented:

If I do not use ensemble_size, this will fail here. Probably better to have something like an 'if "ensemble_size" in conf["trainer"]' conditional.
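
A sketch of the suggested fallback (illustrative; dict.get keeps deterministic configs working without the new key):

# fall back to a single "member" when the config does not define ensemble_size
self.ensemble_size = conf["trainer"].get("ensemble_size", 1)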
