diff --git a/credit/postblock.py b/credit/postblock.py index b475b73..e380b2f 100644 --- a/credit/postblock.py +++ b/credit/postblock.py @@ -715,11 +715,20 @@ def cycle_pattern(self, spec_coef): new_coef = (1.0 - self.alpha) * spec_coef + self.g_n * torch.sqrt(self.alpha) * noise # (lmax, mmax) return new_coef - def initialize_mass_calc(self): - a_vals = self.level_info["a_model"].sel(level=self.level_list).to_numpy() + def initialize_hya_b_era5(self): + a_vals = self.level_info["a_model"].sel(level=self.level_list).to_numpy() / 100 b_vals = self.level_info["b_model"].sel(level=self.level_list).to_numpy() + return a_vals, b_vals + + def initialize_hya_b_cesm(self): + a_vals = self.level_info["hyam"].to_numpy() + b_vals = self.level_info["hybm"].to_numpy() + return a_vals, b_vals + + def initialize_mass_calc(self): + a_vals, b_vals = self.initialize_hya_b_cesm(self) self.register_buffer('a_tensor', - torch.from_numpy(a_vals).view(1, self.levels, 1, 1, 1) / 100, + torch.from_numpy(a_vals).view(1, self.levels, 1, 1, 1), persistent=False) self.register_buffer('b_tensor', torch.from_numpy(b_vals).view(1, self.levels, 1, 1, 1), @@ -728,7 +737,8 @@ def initialize_mass_calc(self): torch.from_numpy(self.surface_area).view(1, 1, 1, self.nlat, 1), persistent=False) self.compute_plev_quantities = compute_pressure_on_mlevs(a_vals=self.a_tensor, b_vals=self.b_tensor, plev_dim=1) - + + def calculate_mass(self, sp): # 1 / g * A * integral(dp) [thickness] return (1.0 / GRAVITY