Skip to content

Commit

Permalink
Qbo wavelet (#836)
Browse files Browse the repository at this point in the history
* Add QBO wavelet calculations

* Update QBO plot for wavelet calculations

Adjust the plot panel to accommodate

* Add comments and docstrings

* Add power level for reference case

* Remove trailing whitespaces

* Fix pre-commit-hooks failures

* Update qbo_driver.py

* Match current space formatting

* Clean up QBO wavelet spectra plots

* Add vertical line to indicate wavelet spectra peak and x-axis label for corresponding periods
* Use square root values for clarity

* Update qbo_plot.py

* plot minor tweak

* more fix

* fix units

---------

Co-authored-by: whannah1 jjbenedict, ChengzhuZhang <zhang40@llnl.gov>
  • Loading branch information
justin-richling and chengzhuzhang authored Sep 17, 2024
1 parent 31e10a3 commit 2b303df
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 8 deletions.
51 changes: 50 additions & 1 deletion e3sm_diags/driver/qbo_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import cdutil
import numpy as np
import scipy.fftpack
from scipy.signal import detrend

from e3sm_diags.derivations import default_regions
from e3sm_diags.driver import utils
Expand Down Expand Up @@ -158,6 +159,19 @@ def get_psd_from_deseason(xraw, period_new):
return psd_x_new0, amplitude_new0


def get_psd_from_wavelet(data):
"""
Return power spectral density using a complex Morlet wavelet spectrum of degree 6
"""
deg = 6
period = np.arange(1, 55 + 1)
freq = 1 / period
widths = deg / (2 * np.pi * freq)
cwtmatr = scipy.signal.cwt(data, scipy.signal.morlet2, widths=widths, w=deg)
psd = np.mean(np.square(np.abs(cwtmatr)), axis=1)
return (period, psd)


def run_diag(parameter: QboParameter) -> QboParameter:
variables = parameter.variables
# The region will always be 5S5N
Expand Down Expand Up @@ -216,6 +230,41 @@ def run_diag(parameter: QboParameter) -> QboParameter:
)
ref["period_new"] = period_new

# Diagnostic 4: calculate the Wavelet
# Target vertical level
pow_spec_lev = 20.0

# Find the closest value for power spectral level in the list
# List of test case vertical levels
test_lev_list = list(test["level"])
closest_lev = min(test_lev_list, key=lambda x: abs(x - pow_spec_lev))
closest_index = test_lev_list.index(closest_lev)
# Grab target vertical level
test_data_avg = test["qbo"][:, closest_index]

# List of reference case vertical levels
ref_lev_list = list(ref["level"])
# Find the closest value for power spectral level in the list
closest_lev = min(ref_lev_list, key=lambda x: abs(x - pow_spec_lev))
closest_index = ref_lev_list.index(closest_lev)
# Grab target vertical level
ref_data_avg = ref["qbo"][:, closest_index]

# convert to anomalies
test_data_avg = test_data_avg - test_data_avg.mean()
ref_data_avg = ref_data_avg - ref_data_avg.mean()

# Detrend the data
test_detrended_data = detrend(test_data_avg)
ref_detrended_data = detrend(ref_data_avg)

test["wave_period"], test_wavelet = get_psd_from_wavelet(test_detrended_data)
ref["wave_period"], ref_wavelet = get_psd_from_wavelet(ref_detrended_data)

# Get square root values of wavelet spectra
test["wavelet"] = np.sqrt(test_wavelet)
ref["wavelet"] = np.sqrt(ref_wavelet)

parameter.var_id = variable
parameter.main_title = (
"QBO index, amplitude, and power spectral density for {}".format(variable)
Expand Down Expand Up @@ -263,7 +312,7 @@ def run_diag(parameter: QboParameter) -> QboParameter:
json_dict = test_json
else:
json_dict = ref_json
json.dump(json_dict, outfile)
json.dump(json_dict, outfile, default=str)
# Get the file name that the user has passed in and display that.
json_output_file_name = os.path.join(
utils.general.get_output_dir(parameter.current_set, parameter),
Expand Down
65 changes: 58 additions & 7 deletions e3sm_diags/plot/cartopy/qbo_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

logger = custom_logger(__name__)

# rect : tuple (left, bottom, width, height)
# All quantities are in fractions of figure width and height.
panel = [
(0.075, 0.70, 0.6, 0.225),
(0.075, 0.425, 0.6, 0.225),
(0.725, 0.425, 0.2, 0.5),
(0.075, 0.075, 0.85, 0.275),
(0.075, 0.75, 0.6, 0.175), # Adjusted height and y position
(0.075, 0.525, 0.6, 0.175), # Adjusted height and y position
(0.725, 0.525, 0.2, 0.4), # Adjusted height and y position
(0.075, 0.285, 0.85, 0.175), # Adjusted height and y position
(0.075, 0.04, 0.85, 0.175), # Adjusted height and y position
]

# Border padding relative to subplot axes for saving individual panels
Expand Down Expand Up @@ -68,6 +71,28 @@ def plot_panel(
)
(p1,) = ax.plot(x["data"], y["data"], "-ok")
(p2,) = ax.plot(x["data2"], y["data2"], "--or")
if n == 3 or n == 4:
# Find the index of the wavelet maximum value
test_ymax_idx = list(y["data"]).index(max(y["data"]))
ref_ymax_idx = list(y["data2"]).index(max(y["data2"]))

# Use the index to get the period value for peak of spectra
test_y_max_xval = list(x["data"])[test_ymax_idx]
ref_y_max_xval = list(x["data2"])[ref_ymax_idx]

# Plot vertical lines for period peaks
ax.axvline(
x=test_y_max_xval,
ymax=max(y["data"]) / y["axis_range"][1],
color="k",
linestyle="-",
)
ax.axvline(
x=ref_y_max_xval,
ymax=max(y["data2"]) / y["axis_range"][1],
color="r",
linestyle="--",
)
plt.grid("on")
ax.legend(
(p1, p2),
Expand All @@ -82,6 +107,11 @@ def plot_panel(
plt.ylim([y["axis_range"][0], y["axis_range"][1]])
plt.yticks(size=label_size)
plt.xscale(x["axis_scale"])
if n == 3 or n == 4:
# Set custom x-axis tick labels to include period corresponding to peak of wavelet spectra
standard_ticks = list(np.arange(x["axis_range"][0], x["axis_range"][1] + 1, 5))
custom_ticks = sorted(standard_ticks + [test_y_max_xval, ref_y_max_xval])
ax.set_xticks(custom_ticks)
plt.xlim([x["axis_range"][0], x["axis_range"][1]])
plt.xticks(size=label_size)

Expand All @@ -91,7 +121,7 @@ def plot_panel(
def plot(parameter, test, ref):
label_size = 14

fig = plt.figure(figsize=(14, 14))
fig = plt.figure(figsize=(14, 18))

months = np.minimum(ref["qbo"].shape[0], test["qbo"].shape[0])
x_test, y_test = np.meshgrid(np.arange(0, months), test["level"])
Expand Down Expand Up @@ -156,9 +186,9 @@ def plot(parameter, test, ref):
title = "QBO Amplitude \n (period = 20-40 months)"
plot_panel(2, fig, "line", label_size, title, x, y)

# Panel 3 (Bottom)
# Panel 3 (Bottom/Top)
x = dict(
axis_range=[0, 50],
axis_range=[5, 50],
axis_scale="linear",
data=test["period_new"],
data_label=test["name"],
Expand All @@ -175,6 +205,27 @@ def plot(parameter, test, ref):
)
title = "QBO Spectral Density (Eq. 18-22 hPa zonal winds)"
plot_panel(3, fig, "line", label_size, title, x, y)

# Panel 4 (Bottom/Bottom)
x = dict(
axis_range=[5, 50],
axis_scale="linear",
data=test["wave_period"],
data_label=test["name"],
data2=ref["wave_period"],
data2_label=ref["name"],
label="Period (months)",
)
y = dict(
axis_range=[-1, 105],
axis_scale="linear",
data=test["wavelet"],
data2=ref["wavelet"],
label="Variance (" + "m\u00b2/s\u00b2" + ")",
)
title = "QBO Wavelet (Eq. 18-22 hPa zonal winds)"
plot_panel(4, fig, "line", label_size, title, x, y)

plt.tight_layout()

# Figure title
Expand Down

0 comments on commit 2b303df

Please sign in to comment.