Skip to content

Commit

Permalink
CDAT Migration Phase 2: Refactor zonal_mean_2d() and `zonal_mean_2d…
Browse files Browse the repository at this point in the history
…_stratosphere()` sets (#774)
  • Loading branch information
tomvothecoder authored Feb 14, 2024
1 parent 6b296b1 commit 57b15fa
Show file tree
Hide file tree
Showing 21 changed files with 2,277 additions and 815 deletions.
2 changes: 1 addition & 1 deletion .vscode/e3sm_diags.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from auxiliary_tools.cdat_regression_testing.base_run_script import run_set

SET_NAME = "zonal_mean_2d"
SET_DIR = "655-zonal-mean-2d"

run_set(SET_NAME, SET_DIR)

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from auxiliary_tools.cdat_regression_testing.base_run_script import run_set

SET_NAME = "zonal_mean_2d_stratosphere"
SET_DIR = "655-zonal-mean-2d-stratosphere"

run_set(SET_NAME, SET_DIR)
6 changes: 3 additions & 3 deletions e3sm_diags/driver/utils/dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,10 @@ def _get_climo_dataset(self, season: str) -> xr.Dataset:
filepath = self._get_climo_filepath(season)
ds = self._open_climo_dataset(filepath)

if self.var in ds.variables:
pass
elif self.var in self.derived_vars_map:
if self.var in self.derived_vars_map:
ds = self._get_dataset_with_derived_climo_var(ds)
elif self.var in ds.data_vars.keys():
pass
else:
raise IOError(
f"Variable '{self.var}' was not in the file '{filepath}', nor was "
Expand Down
22 changes: 14 additions & 8 deletions e3sm_diags/driver/utils/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def has_z_axis(data_var: xr.DataArray) -> bool:
return False


def get_z_axis(data_var: xr.DataArray) -> xr.DataArray:
"""Gets the Z axis coordinates.
def get_z_axis(obj: xr.Dataset | xr.DataArray) -> xr.DataArray:
"""Gets the Z axis coordinates from an xarray object.
Returns True if:
- Data variable has a "Z" axis in the cf-xarray mapping dict
Expand All @@ -133,8 +133,8 @@ def get_z_axis(data_var: xr.DataArray) -> xr.DataArray:
Parameters
----------
data_var : xr.DataArray
The data variable.
obj : xr.Dataset | xr.DataArray
The xarray Dataset or DataArray.
Returns
-------
Expand All @@ -148,18 +148,19 @@ def get_z_axis(data_var: xr.DataArray) -> xr.DataArray:
- https://cdms.readthedocs.io/en/latest/_modules/cdms2/axis.html#AbstractAxis.isLevel
"""
try:
z_coords = xc.get_dim_coords(data_var, axis="Z")
z_coords = xc.get_dim_coords(obj, axis="Z")

return z_coords
except KeyError:
pass

for coord in data_var.coords.values():
for coord in obj.coords.values():
if coord.name in ["lev", "plev", "depth"]:
return coord

raise KeyError(
f"No Z axis coordinate were found in the '{data_var.name}' "
"Make sure the variable has Z axis coordinates"
f"No Z axis coordinates were found in the {type(obj)}. Make sure the "
f"{type(obj)} has Z axis coordinates."
)


Expand Down Expand Up @@ -517,6 +518,7 @@ def _hybrid_to_plevs(

pressure_grid = xc.create_grid(z=z_axis)
pressure_coords = _hybrid_to_pressure(ds, var_key)

# Keep the "axis" and "coordinate" attributes for CF mapping.
with xr.set_options(keep_attrs=True):
result = ds.regridder.vertical(
Expand All @@ -527,6 +529,10 @@ def _hybrid_to_plevs(
target_data=pressure_coords,
)

# Vertical regriding sets the units to "mb", but the original units
# should be preserved.
result[var_key].attrs["units"] = ds[var_key].attrs["units"]

return result


Expand Down
Loading

0 comments on commit 57b15fa

Please sign in to comment.