Skip to content

Commit

Permalink
Replace mask_by with _apply_land_sea_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Aug 22, 2023
1 parent 8d6f2e1 commit ca5795d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 43 deletions.
75 changes: 33 additions & 42 deletions e3sm_diags/derivations/acme_new.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
This module defines functions for deriving variables using other variables.
"""
import copy
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple

Expand Down Expand Up @@ -99,29 +98,30 @@ def convert_units(var: xr.DataArray, target_units: str): # noqa: C901
return var


def mask_by(input_var, maskvar, low_limit=None, high_limit=None):
"""masks a variable var to be missing except where maskvar>=low_limit and maskvar<=high_limit.
None means to omit the constrint, i.e. low_limit = -infinity or high_limit = infinity.
var is changed and returned; we don't make a new variable.
var and maskvar: dimensioned the same variables.
low_limit and high_limit: scalars.
def _apply_land_sea_mask(
var: xr.DataArray, var_mask: xr.DataArray, lower_limit: float
) -> xr.DataArray:
"""Apply a land or sea mask on the variable.
Parameters
----------
var : xr.DataArray
The variable.
var_mask : xr.DataArray
The variable mask ("LANDFRAC" or "OCNFRAC").
lower_limit : float
Update the mask variable with a lower limit. All values below the
lower limit will be masked.
Returns
-------
xr.DataArray
The masked variable.
"""
var = copy.deepcopy(input_var)
if low_limit is None and high_limit is None:
return var
if low_limit is None and high_limit is not None:
maskvarmask = maskvar > high_limit
elif low_limit is not None and high_limit is None:
maskvarmask = maskvar < low_limit
else:
maskvarmask = (maskvar < low_limit) | (maskvar > high_limit)
cond = var_mask > lower_limit
masked_var = var.where(cond=cond, drop=False)

if var.mask is False:
newmask = maskvarmask
else:
newmask = var.mask | maskvarmask
var.mask = newmask
return var
return masked_var


def qflxconvert_units(var):
Expand Down Expand Up @@ -690,10 +690,10 @@ def cosp_histogram_standardize(cld: "FileVariable"):
(("sst",), rename),
(
("TS", "OCNFRAC"),
lambda ts, ocnfrac: mask_by(
lambda ts, ocnfrac: _apply_land_sea_mask(
convert_units(ts, target_units="degC"),
ocnfrac,
low_limit=0.9,
lower_limit=0.9,
),
),
(("SST",), lambda sst: convert_units(sst, target_units="degC")),
Expand Down Expand Up @@ -999,25 +999,14 @@ def cosp_histogram_standardize(cld: "FileVariable"):
(("rtmt",), rename),
]
),
# 'TREFHT_LAND': OrderedDict([
# (('TREFHT_LAND',), rename),
# (('TREFHT', 'LANDFRAC'), lambda trefht, landfrac: mask_by(
# convert_units(trefht, target_units="K"), landfrac, low_limit=0.65))
# ]),
# 'TREFHT_LAND': OrderedDict([
# (('TREFHT_LAND',), lambda t: convert_units(rename(t), target_units="DegC")),
# (('tas',), lambda t: convert_units(t, target_units="DegC")), #special case for GHCN data provided by Jerry
# (('TREFHT', 'LANDFRAC'), lambda trefht, landfrac: mask_by(
# convert_units(trefht, target_units="DegC"), landfrac, low_limit=0.65))
# ]),
"PRECT_LAND": OrderedDict(
[
(("PRECIP_LAND",), rename),
# 0.5 just to match amwg
(
("PRECC", "PRECL", "LANDFRAC"),
lambda precc, precl, landfrac: mask_by(
prect(precc, precl), landfrac, low_limit=0.5
lambda precc, precl, landfrac: _apply_land_sea_mask(
prect(precc, precl), landfrac, lower_limit=0.5
),
),
]
Expand Down Expand Up @@ -1088,10 +1077,10 @@ def cosp_histogram_standardize(cld: "FileVariable"):
),
(
("TGCLDLWP", "OCNFRAC"),
lambda tgcldlwp, ocnfrac: mask_by(
lambda tgcldlwp, ocnfrac: _apply_land_sea_mask(
convert_units(tgcldlwp, target_units="g/m^2"),
ocnfrac,
low_limit=0.65,
lower_limit=0.65,
),
),
]
Expand All @@ -1104,10 +1093,10 @@ def cosp_histogram_standardize(cld: "FileVariable"):
),
(
("PRECC", "PRECL", "OCNFRAC"),
lambda a, b, ocnfrac: mask_by(
lambda a, b, ocnfrac: _apply_land_sea_mask(
aplusb(a, b, target_units="mm/day"),
ocnfrac,
low_limit=0.65,
lower_limit=0.65,
),
),
]
Expand All @@ -1117,7 +1106,9 @@ def cosp_histogram_standardize(cld: "FileVariable"):
(("PREH2O_OCEAN",), lambda x: convert_units(x, target_units="mm")),
(
("TMQ", "OCNFRAC"),
lambda preh2o, ocnfrac: mask_by(preh2o, ocnfrac, low_limit=0.65),
lambda preh2o, ocnfrac: _apply_land_sea_mask(
preh2o, ocnfrac, lower_limit=0.65
),
),
]
),
Expand Down
2 changes: 1 addition & 1 deletion e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
# for each region to a JSON file.
vars_have_z_axis = has_z_axis(dv_test) and has_z_axis(dv_ref)

# TODO: Refactor both conditionals since there logic is similar.
# TODO: Refactor both conditionals since logic is similar.
if not vars_have_z_axis:
for region in regions:
logger.info(f"Selected region: {region}")
Expand Down

0 comments on commit ca5795d

Please sign in to comment.