-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathfates_xarray_funcs.py
181 lines (145 loc) · 7.29 KB
/
fates_xarray_funcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""functions for using fates and xarray"""
import xarray as xr
import numpy as np
def _get_check_dim(dim_short, dataset):
    """Expand a short FATES dimension code and verify the Dataset has that dimension.

    Mainly a helper for deduplex().

    Args:
        dim_short (string): Short dimension code, e.g. "age".
        dataset (xarray Dataset): Dataset that must contain the expanded dimension.

    Raises:
        NameError: The expanded dimension name is not among dataset.dims.

    Returns:
        string: The full dimension name, e.g. "fates_levage".
    """
    long_name = "fates_lev" + dim_short
    if long_name in dataset.dims:
        return long_name
    raise NameError(f"Dimension {long_name} not present in Dataset with dims {dataset.dims}")
def _get_dim_combined(dim1_short, dim2_short):
    """Build the name of the duplexed FATES dimension formed from two short names.

    The name is normally "fates_lev" + dim1_short + dim2_short. FATES abbreviates
    some combinations further (e.g. "scls" + "pft" -> "fates_levscpf"), and those
    are translated through a lookup table.

    Args:
        dim1_short (string): Short name of the first duplexed dimension. E.g., when
            de-duplexing fates_levscpf, dim1_short=scls.
        dim2_short (string): Short name of the second duplexed dimension. E.g., when
            de-duplexing fates_levscpf, dim2_short=pft.

    Returns:
        string: Duplexed dimension name
    """
    # Concatenated long form -> abbreviated name actually used by FATES
    abbreviations = {
        "fates_levcanleaf": "fates_levcnlf",
        "fates_levcanpft": "fates_levcapf",
        "fates_levcdamscls": "fates_levcdsc",
        "fates_levsclsage": "fates_levscag",
        "fates_levsclspft": "fates_levscpf",
    }
    long_form = "fates_lev" + dim1_short + dim2_short
    return abbreviations.get(long_form, long_form)
def deduplex(dataset, this_var, dim1_short, dim2_short, preserve_order=True):
    """Reshape a duplexed FATES dimension into its constituent dimensions

    For example, given a variable with dimensions
        (time, fates_levagepft, lat, lon),
    this will return a DataArray with dimensions
        (time, fates_levage, fates_levpft, lat, lon)
    Or with preserve_order=False:
        (time, fates_levpft, lat, lon, fates_levage).

    The duplexed dimension is treated as blocks of len(dim1) consecutive entries,
    one block per dim2 entry. In other words, dim1 varies fastest: result element
    (i1, i2) comes from index i2 * len(dim1) + i1 of the duplexed dimension.

    Args:
        dataset (xarray Dataset): Dataset containing the variable with dimension to de-duplex
        this_var (string or xarray DataArray): (Name of) variable with dimension to de-duplex
        dim1_short (string): Short name of first duplexed dimension. E.g., when de-duplexing
            fates_levagepft, dim1_short=age.
        dim2_short (string): Short name of second duplexed dimension. E.g., when de-duplexing
            fates_levagepft, dim2_short=pft.
        preserve_order (bool, optional): Preserve order of dimensions of input DataArray? Defaults
            to True. Might be faster if False. See examples above.

    Raises:
        RuntimeError: dim1_short == dim2_short (not yet handled)
        TypeError: Incorrect type of this_var
        NameError: Dimension not found on DataArray or Dataset
        ValueError: Length of the duplexed dimension is not len(dim1) * len(dim2)

    Returns:
        xarray DataArray: De-duplexed variable
    """
    if dim1_short == dim2_short:
        raise RuntimeError("deduplex() can't currently handle dim1_short==dim2_short")

    # Get DataArray
    if isinstance(this_var, xr.DataArray):
        da_in = this_var
    elif isinstance(this_var, str):
        da_in = dataset[this_var]
    else:
        # Format the type name; concatenating str + type would itself raise a
        # different, confusing TypeError and hide this message.
        raise TypeError(
            f"this_var must be either string or DataArray, not {type(this_var).__name__}"
        )

    # Get combined dim name
    dim_combined = _get_dim_combined(dim1_short, dim2_short)
    if dim_combined not in da_in.dims:
        raise NameError(
            f"Dimension {dim_combined} not present in DataArray with dims {da_in.dims}"
        )

    # Get individual dim names
    dim1 = _get_check_dim(dim1_short, dataset)
    dim2 = _get_check_dim(dim2_short, dataset)

    n_dim1 = dataset.sizes[dim1]
    n_dim2 = dataset.sizes[dim2]
    n_combined = da_in.sizes[dim_combined]
    # Without this check, a few extra trailing entries would be silently dropped
    # by the strided selection below.
    if n_combined != n_dim1 * n_dim2:
        raise ValueError(
            f"Length of {dim_combined} ({n_combined}) != length of {dim1} ({n_dim1})"
            f" * length of {dim2} ({n_dim2})"
        )

    # Split the duplexed dimension into its components. Each rolling window of
    # n_dim1 consecutive entries becomes the new dim1 axis. Keeping only the
    # windows that end at n_dim1-1, 2*n_dim1-1, ... gives non-overlapping blocks,
    # one per dim2 entry, so the (renamed) combined dim becomes dim2.
    da_out = (
        da_in.rolling({dim_combined: n_dim1}, center=False)
        .construct(dim1)
        .isel({dim_combined: slice(n_dim1 - 1, None, n_dim1)})
        .rename({dim_combined: dim2})
        .assign_coords({dim1: dataset[dim1]})
        .assign_coords({dim2: dataset[dim2]})
    )

    # construct() appended dim1 as the last dimension. Move it to sit immediately
    # before dim2, where the duplexed dimension used to be.
    if preserve_order:
        new_dim_order = []
        for dim in da_out.dims:
            if dim == dim2:
                new_dim_order.append(dim1)
            if dim != dim1:
                new_dim_order.append(dim)
        da_out = da_out.transpose(*new_dim_order)

    return da_out
def agefuel_to_age_by_fuel(agefuel_var, dataset):
    """Reshape a FATES variable duplexed by age and fuel size into separate age and fuel dimensions.

    Thin wrapper around deduplex(dataset, agefuel_var, "age", "fuel", preserve_order=False).

    Args:
        agefuel_var (xarray DataArray or string): (Name of) variable that has the FATES
            agefuel dimension (fates_levagefuel).
        dataset (xarray Dataset): Dataset that has both the FATES age (fates_levage) and
            fuel (fates_levfuel) dimensions, possibly the Dataset containing agefuel_var.

    Returns:
        xarray DataArray: The variable with the age and fuel dimensions disentangled.
        Because preserve_order=False, fates_levfuel takes the position of the original
        agefuel dimension and fates_levage is appended as the last dimension.
    """
    return deduplex(dataset, agefuel_var, "age", "fuel", preserve_order=False)
def scpf_to_scls_by_pft(scpf_var, dataset):
    """Reshape a FATES variable duplexed by size class and PFT into separate dimensions.

    Args:
        scpf_var (xarray DataArray or string): (Name of) variable that has the FATES
            SCPF dimension (fates_levscpf).
        dataset (xarray Dataset): Dataset that has both the FATES size-class
            (fates_levscls) and PFT (fates_levpft) dimensions, possibly the Dataset
            containing scpf_var.

    Returns:
        xarray DataArray: The variable with the size-class and PFT dimensions
        disentangled. fates_levpft takes the position of the original SCPF dimension,
        and fates_levscls is appended as the last dimension.
    """
    return deduplex(
        dataset,
        scpf_var,
        dim1_short="scls",
        dim2_short="pft",
        preserve_order=False,
    )
def scag_to_scls_by_age(scag_var, dataset):
    """Reshape a FATES variable duplexed by size class and age into separate dimensions.

    Thin wrapper around deduplex(dataset, scag_var, "scls", "age", preserve_order=False).

    Args:
        scag_var (xarray DataArray or string): (Name of) variable that has the FATES
            SCAG dimension (fates_levscag).
        dataset (xarray Dataset): Dataset that has both the FATES size-class
            (fates_levscls) and age (fates_levage) dimensions, possibly the Dataset
            containing scag_var.

    Returns:
        xarray DataArray: The variable with the size-class and age dimensions
        disentangled. fates_levage takes the position of the original SCAG dimension,
        and fates_levscls is appended as the last dimension.
    """
    return deduplex(dataset, scag_var, "scls", "age", preserve_order=False)
def monthly_to_annual(array):
    """Calculate annual means from monthly data.

    Months are weighted by their lengths in a noleap (365-day) calendar.
    Originally written by Keith Lindsay.

    Args:
        array (xarray DataArray): Monthly data with a "time" dimension. Each
            consecutive block of 12 time steps is treated as one year, and the weights
            are applied in January..December order. This presumes the series starts in
            January (NOTE(review): confirm with callers). Any trailing partial year is
            dropped.

    Returns:
        xarray DataArray: One length-weighted mean per complete 12-step block, labelled
        with the time coordinate of the block's last (12th) month. A NaN in any month
        makes that year's mean NaN.
    """
    mon_day = xr.DataArray(
        np.array([31.0, 28.0, 31.0, 30.0, 31.0, 30.0, 31.0, 31.0, 30.0, 31.0, 30.0, 31.0]),
        dims=["month"],
    )
    mon_wgt = mon_day / mon_day.sum()
    by_year = (
        array.rolling(time=12, center=False)
        .construct("month")
        # Keep only windows ending at steps 11, 23, ...: [0..11], [12..23], ...
        .isel(time=slice(11, None, 12))
    )
    # Equivalent to by_year.dot(mon_wgt, dims=["month"]), but avoids the `dims`
    # keyword that newer xarray releases deprecate in favour of `dim`.
    # skipna=False mirrors dot(): NaNs propagate instead of being skipped.
    return (by_year * mon_wgt).sum("month", skipna=False)
def monthly_to_month_by_year(array):
    """Reshape monthly data into a month x year layout (for calculating climatologies, etc.).

    Args:
        array (xarray DataArray): Monthly data with a "time" dimension. Each
            consecutive block of 12 time steps becomes one year, and any trailing
            partial year is dropped.

    Returns:
        xarray DataArray: Data in which "time" is renamed to "year" (keeping the time
        coordinate of each block's 12th month) and a new, coordinate-less "month"
        dimension of length 12 is appended last.
    """
    windows = array.rolling(time=12, center=False).construct("month")
    # Non-overlapping 12-step blocks: windows ending at steps 11, 23, 35, ...
    full_years = windows.isel(time=slice(11, None, 12))
    return full_years.rename({"time": "year"})