Ensemble capability for single-step training #137
base: main
Conversation
@dkimpara the batch_size > 1 ... does this mean it will support the new datasets I added, or are you doing it yourself here?
I haven't yet configured it for your datasets. I will do that once your PR is settled. What I meant was that the ensemble features are compatible with batches with more than one target.
@dkimpara trainerERA5_v2.py will be deprecated in my PR; we will only need trainerERA5_multistep_grad_accum.py going forward, so you should add the ensemble change to that script as well. Keep what you have because I am not going to remove _v2 just yet since other people are still using it.
Finished review. My request is that deterministic training should skip the CRPS code blocks.
# if samples in the batch are ordered (x, y, z) then the result tensor is (x, x, ..., y, y, ..., z, z, ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)
This place can have an if condition, so the deterministic training routine will not touch this part.
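A minimal sketch of that guard, using hypothetical stand-in tensors for the trainer's real conf and x (toy shapes only, not the PR's actual change):
import torch

conf = {"trainer": {"ensemble_size": 1}}      # stand-in config; parser defaults to 1
x = torch.randn(4, 3, 2, 8, 8)                # stand-in batch, (b, c, t, lat, lon)

ensemble_size = conf["trainer"]["ensemble_size"]
if ensemble_size > 1:
    # only ensemble runs pay for the interleaved copy;
    # deterministic training (ensemble_size == 1) leaves x untouched
    x = torch.repeat_interleave(x, ensemble_size, 0)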
# if samples in the batch are ordered (x, y, z) then the result tensor is (x, x, ..., y, y, ..., z, z, ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)
if condition
# calculate ensemble mean; if ensemble_size=1, does nothing
pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:])  # (b, ensemble, c, t, lat, lon)
pred = pred.mean(dim=1)
if condition
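A sketch of the same kind of guard for the ensemble-mean block, again with toy shapes (b=2, ensemble_size=4) standing in for the trainer's real pred and y:
import torch

ensemble_size = 4
y = torch.randn(2, 3, 1, 8, 8)                      # (b, c, t, lat, lon)
pred = torch.randn(2 * ensemble_size, 3, 1, 8, 8)   # (b * ensemble, c, t, lat, lon)

if ensemble_size > 1:
    # reshape to (b, ensemble, c, t, lat, lon) and average over the ensemble dim;
    # deterministic runs never enter this branch, so there is no extra view/mean cost
    pred = pred.view(y.shape[0], ensemble_size, *y.shape[1:])
    pred = pred.mean(dim=1)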
Yeah -- will also fail here.
@@ -710,6 +710,9 @@ def credit_main_parser(
        "train_batch_size" in conf["trainer"]
    ), "Training set batch size ('train_batch_size') is missing from conf['trainer']"

    if "ensemble_size" not in conf["trainer"]:
        conf["trainer"]["ensemble_size"] = 1
Can you have an option to skip the ensemble routine? For those who do not use ensemble training, going through the ensemble code can cause a slow-down.
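One possible shape for such an option, sketched with hypothetical names (forward_step is not part of the PR): derive the ensemble size once from the parsed config and branch on it per step, so deterministic configs (ensemble_size = 1, the parser default) never touch the ensemble routine.
import torch

def forward_step(model, x, y, conf):
    ens = conf["trainer"].get("ensemble_size", 1)
    if ens > 1:
        x = torch.repeat_interleave(x, ens, 0)                   # (b, ...) -> (b * ens, ...)
        pred = model(x)
        pred = pred.view(y.shape[0], ens, *y.shape[1:]).mean(1)  # ensemble mean back to (b, ...)
    else:
        pred = model(x)                                          # deterministic path, no extra cost
    return pred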
@@ -28,11 +28,17 @@ def __init__(self, conf, predict_mode=False):
        # DO NOT apply these weights during metrics computations, only on the loss during
        self.w_var = None

        self.ensemble_size = conf["trainer"]["ensemble_size"]
If I do not use ensemble_size this will fail here. Probably better to have something like an `if "ensemble_size" in conf["trainer"]` conditional.
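A self-contained sketch of that fallback (the class name EnsembleAwareLoss is made up; only the __init__ guard matters):
class EnsembleAwareLoss:
    def __init__(self, conf, predict_mode=False):
        # .get() falls back to 1, so older configs without "ensemble_size"
        # (deterministic training) do not raise a KeyError here
        self.ensemble_size = conf["trainer"].get("ensemble_size", 1)

# usage: a config with no ensemble_size still constructs cleanly
loss = EnsembleAwareLoss({"trainer": {}})
assert loss.ensemble_size == 1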
New config field:
- int trainer.ensemble_size = 1 by default (also set by parser)

High-level code design (see the sketch after this list):
- torch.repeat_interleave to copy samples, transforming batches from (b, ...) to (b * ensemble_size, ...). This new tensor is then passed into the model.
- torch.vmap to vectorize and lift the single-obs modulus loss to handle ensembles with multiple target obs (batch_size > 1).

Future features (did not put them in this PR so as not to clutter the config)

Testing (on casper):
applications/train.py -c config/test_cesm_ensemble.yml -l 1
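A small, self-contained sketch of the repeat_interleave + vmap design described above, using toy shapes; single_obs_loss is a placeholder (ensemble-mean MSE) standing in for the modulus/CRPS loss, not the repo's actual implementation:
import torch

b, ensemble_size, c = 3, 4, 5
x = torch.randn(b, c)

# (b, c) -> (b * ensemble_size, c); (x0, x1, x2) becomes (x0, x0, ..., x1, x1, ..., x2, x2, ...)
x_rep = torch.repeat_interleave(x, ensemble_size, 0)

# stand-in for the model's ensemble predictions
pred = x_rep + 0.1 * torch.randn_like(x_rep)
y = torch.randn(b, c)

# "single obs" loss: one target (c,) scored against its ensemble (ensemble_size, c)
def single_obs_loss(ens_pred, target):
    return ((ens_pred.mean(dim=0) - target) ** 2).mean()

# regroup predictions to (b, ensemble_size, c) and vmap over the batch dimension,
# lifting the single-obs loss to batch_size > 1
pred_grouped = pred.view(b, ensemble_size, c)
loss = torch.vmap(single_obs_loss)(pred_grouped, y).mean()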