-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
119 lines (97 loc) · 3.22 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import cartopy
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
from tools import t_test
# Configuration for cartopy plots: point both of cartopy's data-directory
# settings ("data_dir" and "pre_existing_data_dir") at one shared directory,
# so map features such as ax.coastlines() are looked up there.
# NOTE(review): this absolute path looks specific to one HPC filesystem;
# importing this module on another machine will still set it — confirm intent.
_CARTOPY_DATA_DIR = (
    "/discover/nobackup/projects/jh_tutorials/JH_examples/JH_datafiles/Cartopy"
)
cartopy.config.update(
    dict.fromkeys(("data_dir", "pre_existing_data_dir"), _CARTOPY_DATA_DIR)
)
def plot_training_loss(all_train_losses, EPOCHS):
    """Draw and show the per-epoch training-loss curve of the first run.

    Parameters
    ----------
    all_train_losses : sequence of sequences
        Loss histories; only ``all_train_losses[0]`` is plotted.
    EPOCHS : int
        Number of epochs. The x-axis is ``range(EPOCHS)``, so its length
        should equal ``len(all_train_losses[0])``.

    Draws on the current pyplot axes and then calls ``plt.show()``.
    """
    epoch_axis = range(EPOCHS)
    first_run_losses = all_train_losses[0]

    ax = plt.gca()
    # NOTE(review): the line is labelled, but no legend is ever drawn.
    ax.plot(epoch_axis, first_run_losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("L1 Loss")
    ax.set_title("Training Loss")
    ax.grid()
    plt.show()
def _global_mean(data, variable, weights):
    """Return ``data[variable]`` averaged over ``lat``/``lon`` with ``weights``."""
    return data[variable].weighted(weights).mean(dim=["lat", "lon"])


def global_timeseries_plot(Y_pred, Y_train, Y_test, VARIABLE):
    """Plot weighted global-mean timeseries of training runs, predictions and test data.

    Parameters
    ----------
    Y_pred : dataset with a ``lat`` coordinate
        Emulator predictions (black line). Its ``lat`` coordinate also
        defines the latitude weights applied to every series, so all inputs
        are assumed to share that grid.
    Y_train : sequence of datasets
        Training runs, each drawn in grey; only the first is labelled.
    Y_test : dataset
        Held-out target data, drawn in green.
    VARIABLE : sequence of str
        Only ``VARIABLE[0]`` is plotted.

    The figure is not shown here; the caller is expected to call
    ``plt.show()`` or save it.
    """
    # cos(latitude) weighting: the usual area proxy on a regular lat/lon grid.
    weights = np.cos(np.deg2rad(Y_pred.lat))
    weights.name = "weights"
    variable = VARIABLE[0]

    for i, run in enumerate(Y_train):
        # Label only the first training run so the legend has a single entry.
        label_kwargs = {"label": "Training Data"} if i == 0 else {}
        _global_mean(run, variable, weights).plot(color="grey", **label_kwargs)

    _global_mean(Y_pred, variable, weights).plot(
        color="black", label="Model Predictions"
    )
    _global_mean(Y_test, variable, weights).plot(
        color="green", label="Test Data (SSP245)"
    )

    plt.title("Global Average Timeseries")
    plt.xlabel("Year")
    # NOTE(review): label is fixed to precipitation in mm/day regardless of
    # VARIABLE; global_anomaly_plot labels its colorbar "kg/m²/s" — confirm units.
    plt.ylabel("Precipitation (mm/day)")
    plt.legend()
    plt.grid()
def _period_mean(data, variable, start_year, end_year):
    """Return ``data[variable]`` averaged over years ``start_year``..``end_year`` (label slice, inclusive)."""
    return data[variable].sel(year=slice(start_year, end_year)).mean(dim="year")


def global_anomaly_plot(
    Y_pred, Y_test, p_value, VARIABLE, start_year=2080, end_year=2100
):
    """Map the emulator's period-mean error (prediction minus test) and show it.

    Parameters
    ----------
    Y_pred : dataset with ``year``, ``lat`` and ``lon`` dimensions
        Emulator predictions.
    Y_test : dataset
        Target data on the same grid.
    p_value :
        NOTE(review): accepted but never used; presumably meant for
        masking/stippling insignificant differences — confirm with callers.
    VARIABLE : sequence of str
        Only ``VARIABLE[0]`` is mapped.
    start_year, end_year : int, optional
        Averaging period (defaults 2080-2100, matching the original behaviour);
        also shown on the last title line.
    """
    variable = VARIABLE[0]
    # Select the variable before averaging so only that field is reduced.
    diff = _period_mean(Y_pred, variable, start_year, end_year) - _period_mean(
        Y_test, variable, start_year, end_year
    )

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111, projection=ccrs.Robinson())

    # Fixed symmetric colour range so errors are comparable across calls.
    diff.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        vmin=-1,
        vmax=1,
        cmap=plt.cm.coolwarm,
        extend="both",
        cbar_kwargs={
            # NOTE(review): "kg/m²/s" disagrees with the "mm/day" label in
            # global_timeseries_plot; with a ±1 colour range mm/day seems more
            # plausible — confirm the data's units.
            "label": "kg/m²/s",
            "shrink": 0.75,
            "extendfrac": 0.05,
            "extendrect": False,
            "aspect": 30,
            "format": "%.1f",
            "ticks": [-1, -0.5, 0, 0.5, 1],
        },
    )
    ax.coastlines()

    plt.title(
        "Precipitation Error (Y_hat - Y_test)\n"
        "Emulator vs. NASA Global Climate Model\n"
        f"{start_year}-{end_year}"
    )
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.show()