From 6ce1c6bdd88b2e5fbc70328a972686d2b8609525 Mon Sep 17 00:00:00 2001 From: dkimpara Date: Thu, 19 Dec 2024 16:41:44 -0700 Subject: [PATCH] ensemble capability for batch_size>1 --- config/test_cesm_ensemble.yml | 28 +++++++++++++++------------- credit/loss.py | 22 +++++++++++++++------- credit/metrics.py | 8 +++++++- tests/test_loss.py | 24 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 tests/test_loss.py diff --git a/config/test_cesm_ensemble.yml b/config/test_cesm_ensemble.yml index c09bc17..20e089b 100644 --- a/config/test_cesm_ensemble.yml +++ b/config/test_cesm_ensemble.yml @@ -90,7 +90,7 @@ trainer: cpu_offload: False activation_checkpoint: True - load_weights: True + load_weights: False load_optimizer: False load_scaler: False load_sheduler: False @@ -109,12 +109,12 @@ trainer: ensemble_size: 2 long_rollout: True - batches_per_epoch: 1000 - valid_batches_per_epoch: 4 + batches_per_epoch: 1 + valid_batches_per_epoch: 1 stopping_patience: 999 start_epoch: 0 - num_epoch: 50 + num_epoch: 1 reload_epoch: False epochs: &epochs 100 @@ -185,7 +185,7 @@ model: T: 1 skebs: - activate: True + activate: False lmax: None mmax: None freeze_base_model_weights: True @@ -281,12 +281,14 @@ predict: # save_format: "nc" pbs: #derecho - conda: "/glade/u/home/dkimpara/credit-derecho" + conda: "/glade/u/home/dkimpara/credit" project: "NAML0001" - job_name: "wxformer_1h" - walltime: "12:00:00" - nodes: 2 - ncpus: 64 - ngpus: 4 - mem: '480GB' - queue: 'main' + job_name: "test_cesm_ensemble" + walltime: "00:15:00" + nodes: 1 + ncpus: 8 + ngpus: 1 + mem: '32GB' + gpu_type: 'v100' + project: 'NAML0001' + queue: 'casper' diff --git a/credit/loss.py b/credit/loss.py index 1d78180..eb2dc43 100644 --- a/credit/loss.py +++ b/credit/loss.py @@ -211,9 +211,18 @@ class KCRPSLoss(nn.Module): def __init__(self, reduction, biased: bool = False): super().__init__() self.biased = biased - + self.batched_forward = torch.vmap(self.single_sample_forward) + def forward(self, target, pred): - """Forward pass for KCRPS loss + # integer division but will error out next op if there is a remainder + ensemble_size = pred.shape[0] // target.shape[0] + pred.shape[0] % target.shape[0] + pred = pred.view(target.shape[0], ensemble_size, *target.shape[1:]) #b, ensemble, c, t, lat, lon + # apply single_sample_forward to each dim + target = target.unsqueeze(1) + return self.batched_forward(target, pred).squeeze(1) + + def single_sample_forward(self, target, pred): + """Forward pass for KCRPS loss for a single sample Args: prediction (torch.Tensor): Predicted tensor. @@ -225,8 +234,7 @@ def forward(self, target, pred): pred = torch.movedim(pred, 0, -1) return self._kernel_crps_implementation(pred, target, self.biased) - @torch.jit.script - def _kernel_crps_implementation(pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor: + def _kernel_crps_implementation(self, pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor: """An O(m log m) implementation of the kernel CRPS formulas""" skill = torch.abs(pred - obs[..., None]).mean(-1) pred, _ = torch.sort(pred) @@ -549,10 +557,10 @@ def __init__(self, conf, validation=False): ) self.validation = validation - if self.validation: - self.loss_fn = nn.L1Loss(reduction="none") - elif conf["loss"]["training_loss"] == "KCRPS": # for ensembles, load same loss for train and valid + if conf["loss"]["training_loss"] == "KCRPS": # for ensembles, load same loss for train and valid self.loss_fn = load_loss(self.training_loss, reduction="none") + elif self.validation: + self.loss_fn = nn.L1Loss(reduction="none") else: self.loss_fn = load_loss(self.training_loss, reduction="none") diff --git a/credit/metrics.py b/credit/metrics.py index 36bb2ae..053cd0a 100644 --- a/credit/metrics.py +++ b/credit/metrics.py @@ -10,7 +10,7 @@ def __init__(self, conf, predict_mode=False): atmos_vars = conf["data"]["variables"] surface_vars = conf["data"]["surface_variables"] diag_vars = conf["data"]["diagnostic_variables"] - + levels = ( conf["model"]["levels"] if "levels" in conf["model"] @@ -28,11 +28,17 @@ def __init__(self, conf, predict_mode=False): # DO NOT apply these weights during metrics computations, only on the loss during self.w_var = None + self.ensemble_size = conf["trainer"]["ensemble_size"] + def __call__(self, pred, y, clim=None, transform=None, forecast_datetime=0): if transform is not None: pred = transform(pred) y = transform(y) + # calculate ensemble mean, if ensemble_size=1, does nothing + pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:]) #b, ensemble, c, t, lat, lon + pred = pred.mean(dim=1) + # Get latitude and variable weights w_lat = ( self.w_lat.to(dtype=pred.dtype, device=pred.device) diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..0137098 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,24 @@ +import os + +import torch + +from credit.loss import KCRPSLoss + +TEST_FILE_DIR = "/".join(os.path.abspath(__file__).split("/")[:-1]) +CONFIG_FILE_DIR = os.path.join( + "/".join(os.path.abspath(__file__).split("/")[:-2]), "config" +) + + +def test_KCRPS(): + loss_fn = KCRPSLoss("none") + batch_size = 2 + ensemble_size = 5 + + target = torch.randn(batch_size, 10, 1, 40, 50) + pred = torch.randn(batch_size * ensemble_size, 10, 1, 40, 50) + + loss = loss_fn(target, pred) + assert not torch.isnan(loss).any() + + \ No newline at end of file