Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing African easterly wave density plots in TC analysis #851

Merged
merged 7 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions e3sm_diags/driver/tc_analysis_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def run_diag(parameter: TCAnalysisParameter) -> TCAnalysisParameter:
test_data_path,
"aew_hist_{}_{}_{}.nc".format(test_name, test_start_yr, test_end_yr),
)
test_aew_hist = cdms2.open(test_aew_file)(
"density", lat=(0, 35, "ccb"), lon=(-180, 0, "ccb"), squeeze=1
)
test_aew_hist = cdms2.open(test_aew_file)("density", squeeze=1)

test_data = collections.OrderedDict()
ref_data = collections.OrderedDict()
Expand Down Expand Up @@ -134,9 +132,8 @@ def run_diag(parameter: TCAnalysisParameter) -> TCAnalysisParameter:
"density", lat=(-60, 60, "ccb"), squeeze=1
)
ref_aew_file = os.path.join(reference_data_path, "aew_hist_ERA5_2010_2014.nc")
ref_aew_hist = cdms2.open(ref_aew_file)(
"density", lat=(0, 35, "ccb"), lon=(180, 360, "ccb"), squeeze=1
)
ref_aew_hist = cdms2.open(ref_aew_file)("density", squeeze=1)

ref_data["cyclone_density"] = ref_cyclones_hist
ref_data["cyclone_num_years"] = 40 # type: ignore
ref_data["aew_density"] = ref_aew_hist
Expand All @@ -163,13 +160,40 @@ def generate_tc_metrics_from_te_stitch_file(te_stitch_file: str) -> Dict[str, An
"""
logger.info("\nGenerating TC Metrics from TE Stitch Files")
logger.info("============================================")
if not os.path.exists(te_stitch_file):
raise FileNotFoundError(f"The file {te_stitch_file} does not exist.")

with open(te_stitch_file) as f:
lines = f.readlines()
lines_orig = f.readlines()

if not lines_orig:
raise ValueError(f"The file {te_stitch_file} is empty.")

line_ind = []
data_start_year = int(te_stitch_file.split(".")[-2].split("_")[-2])
data_end_year = int(te_stitch_file.split(".")[-2].split("_")[-1])
for i in range(0, np.size(lines_orig)):
if lines_orig[i][0] == "s":
year = int(lines_orig[i].split("\t")[2])

if year <= data_end_year:
line_ind.append(i)

# Remove excessive time points cross year bounds from 6 hourly data
end_ind = line_ind[-1]
lines = lines_orig[0:end_ind]
Comment on lines +172 to +184
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor suggestion, I would extract this block of logic into a private function called _remove_time_crossing_bounds() or something to make the main run_diags() cleaner.

However, this can be done in the refactored codebase too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion. I considered to have a function instead as well. I think we can update the refactored code base when bringing more fix and enhancement to tc set.


# Calculate number of storms and max length
num_storms, max_len = _calc_num_storms_and_max_len(lines)
# Parse variables from TE stitch file
te_stitch_vars = _get_vars_from_te_stitch(lines, max_len, num_storms)
# Add year info
te_stitch_vars["year_start"] = data_start_year
te_stitch_vars["year_end"] = data_end_year
te_stitch_vars["num_years"] = data_end_year - data_start_year + 1
Comment on lines +190 to +193
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, you moved the year info out of _get_vars_from_te_stitch() because they are now different after removing the excessive time points that cross year bounds. Is this right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The start and end years were previously based on data from te_stitches files. It worked fine earlier because in the simulated data, there are always TCs each year. It became problematic when testing v3 datasets, years with no TCs get skipped. The new code now use start and end included in file name, which represent the actually year range being assessed.

logger.info(
f"TE Start Year: {te_stitch_vars['year_start']}, TE End Year: {te_stitch_vars['year_end']}, Total Years: {te_stitch_vars['num_years']}"
)

# Use E3SM land-sea mask
mask_path = os.path.join(e3sm_diags.INSTALL_PATH, "acme_ne30_ocean_land_mask.nc")
Expand Down Expand Up @@ -246,15 +270,11 @@ def _get_vars_from_te_stitch(
vars_dict = {k: np.empty((max_len, num_storms)) * np.nan for k in keys}

index = 0
year_start = int(lines[0].split("\t")[2])
year_end = year_start

for line in lines:
line_split = line.split("\t")
if line[0] == "s":
index = index + 1
year = int(line_split[2])
year_end = max(year, year_start)
k = 0
else:
k = k + 1
Expand All @@ -265,13 +285,6 @@ def _get_vars_from_te_stitch(
vars_dict["yearmc"][k - 1, index - 1] = float(line_split[6])
vars_dict["monthmc"][k - 1, index - 1] = float(line_split[7])

vars_dict["year_start"] = year_start # type: ignore
vars_dict["year_end"] = year_end # type: ignore
vars_dict["num_years"] = year_end - year_start + 1 # type: ignore
logger.info(
f"TE Start Year: {vars_dict['year_start']}, TE End Year: {vars_dict['year_end']}, Total Years: {vars_dict['num_years']}"
)

return vars_dict


Expand Down
6 changes: 0 additions & 6 deletions tests/e3sm_diags/drivers/test_tc_analysis_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ def test_correct_output(self):
"vsmc": np.array([[1.94, np.nan]]),
"yearmc": np.array([[1, np.nan]]),
"monthmc": np.array([[1, np.nan]]),
"year_start": 90,
"year_end": 90,
"num_years": 1,
}
result = _get_vars_from_te_stitch(lines, max_len, num_storms)

Expand All @@ -70,9 +67,6 @@ def test_correct_output(self):
np.array_equal(result["vsmc"], expected["vsmc"])
np.array_equal(result["yearmc"], expected["yearmc"])
np.array_equal(result["monthmc"], expected["monthmc"])
self.assertEqual(result["year_start"], expected["year_start"])
self.assertEqual(result["year_end"], expected["year_end"])
self.assertEqual(result["num_years"], expected["num_years"])


class TestDeriveMetricsPerBasin(TestCase):
Expand Down
Loading