-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ocean data regridding to Gaussian grid functionality
- Loading branch information
1 parent
939a6b6
commit 679e33a
Showing
5 changed files
with
369 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Regrid | ||
A convenience wrapper on top of xesmf library that regrids tripolar grid ocean data to a gaussian grid. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Regrid: | ||
input_path: . | ||
output_path: . | ||
rotation_file: ./ocn_rotation_mx025.nc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Tools for working with Gaussian grids.""" | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import functools | ||
|
||
import numpy as np | ||
import numpy.linalg as la | ||
from numpy.polynomial.legendre import legcompanion, legder, legval | ||
|
||
|
||
def __single_arg_fast_cache(func): | ||
"""Caching decorator for functions of one argument.""" | ||
|
||
class CachingDict(dict): | ||
def __missing__(self, key): | ||
result = self[key] = func(key) | ||
return result | ||
|
||
@functools.wraps(func) | ||
def __getitem__(self, *args, **kwargs): | ||
return super(CachingDict, self).__getitem__(*args, **kwargs) | ||
|
||
return CachingDict().__getitem__ | ||
|
||
|
||
@__single_arg_fast_cache | ||
def gaussian_latitudes(n): | ||
"""Construct latitudes and latitude bounds for a Gaussian grid. | ||
Args: | ||
* n: | ||
The Gaussian grid number (half the number of latitudes in the | ||
grid. | ||
Returns: | ||
A 2-tuple where the first element is a length `n` array of | ||
latitudes (in degrees) and the second element is an `(n, 2)` | ||
array of bounds. | ||
""" | ||
if abs(int(n)) != n: | ||
raise ValueError("n must be a non-negative integer") | ||
nlat = 2 * n | ||
# Create the coefficients of the Legendre polynomial and construct the | ||
# companion matrix: | ||
cs = np.array([0] * nlat + [1], dtype=int) | ||
cm = legcompanion(cs) | ||
# Compute the eigenvalues of the companion matrix (the roots of the | ||
# Legendre polynomial) taking advantage of the fact that the matrix is | ||
# symmetric: | ||
roots = la.eigvalsh(cm) | ||
roots.sort() | ||
# Improve the roots by one application of Newton's method, using the | ||
# solved root as the initial guess: | ||
fx = legval(roots, cs) | ||
fpx = legval(roots, legder(cs)) | ||
roots -= fx / fpx | ||
# The roots should exhibit symmetry, but with a sign change, so make sure | ||
# this is the case: | ||
roots = (roots - roots[::-1]) / 2.0 | ||
# Compute the Gaussian weights for each interval: | ||
fm = legval(roots, cs[1:]) | ||
fm /= np.abs(fm).max() | ||
fpx /= np.abs(fpx).max() | ||
weights = 1.0 / (fm * fpx) | ||
# Weights should be symmetric and sum to two (unit weighting over the | ||
# interval [-1, 1]): | ||
weights = (weights + weights[::-1]) / 2.0 | ||
weights *= 2.0 / weights.sum() | ||
# Calculate the bounds from the weights, still on the interval [-1, 1]: | ||
bounds1d = np.empty([nlat + 1]) | ||
bounds1d[0] = -1 | ||
bounds1d[1:-1] = -1 + weights[:-1].cumsum() | ||
bounds1d[-1] = 1 | ||
# Convert the bounds to degrees of latitude on [-90, 90]: | ||
bounds1d = np.rad2deg(np.arcsin(bounds1d)) | ||
bounds2d = np.empty([nlat, 2]) | ||
bounds2d[:, 0] = bounds1d[:-1] | ||
bounds2d[:, 1] = bounds1d[1:] | ||
# Convert the roots from the interval [-1, 1] to latitude values on the | ||
# interval [-90, 90] degrees: | ||
latitudes = np.rad2deg(np.arcsin(roots)) | ||
return latitudes, bounds2d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
import glob | ||
import sys | ||
import os | ||
import yaml | ||
import xarray as xr | ||
import numpy as np | ||
import xesmf as xe | ||
from gaussian_grid import gaussian_latitudes | ||
|
||
|
||
class Regrid: | ||
""" | ||
Regrid ocean dataset that is on a tripolar grid to a Gaussian grid | ||
Required fields in config: | ||
input_path (str): path where ocean netcdf files (ocn_*.nc) are located | ||
output_path (str): path where regridded data and interpolation weights are stored | ||
rotation_file (str): path to file containing rotation fields "sin_rot" and "cos_rot" | ||
Usage examples | ||
Construct Regrid object, specifying output nlat x nlon and optionally config fiel | ||
>>> rg = Regrid(180, 360, config_filename = "config-regrid-ocean.yaml") | ||
Compile list of files to regrid | ||
>>> files = glob.glob(f"./input/ocn_2016_08_02_??.nc") | ||
>>> files.sort() | ||
Create regridder to compute interpolation weights | ||
>>> rg.create_regridder(files[0]) | ||
Regrid all files using weights computed above | ||
>>> for file in files: | ||
>>> rg.regrid(file) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
nlat: int, | ||
nlon: int, | ||
interp_method: str = "bilinear", | ||
config_filename: str = "config-regrid.yaml", | ||
): | ||
super(Regrid, self).__init__() | ||
name = self.__class__.__name__ | ||
|
||
# read configuration from yaml file | ||
with open(config_filename, "r") as f: | ||
contents = yaml.safe_load(f) | ||
self.config = contents[name] | ||
|
||
# specify an output resolution | ||
self.nlon_o = nlon | ||
self.nlat_o = nlat | ||
self.interp_method = interp_method | ||
|
||
def compute_gaussian_grid(self, nlat, nlon): | ||
"""Compute gaussian grid latitudes and longitudes""" | ||
latitudes, _ = gaussian_latitudes(nlat // 2) | ||
longitudes = np.linspace(0, 360, nlon, endpoint=False) | ||
return latitudes, longitudes | ||
|
||
def compute_latlon_grid(self, nlat, nlon): | ||
"""Compute regular latlong grid coordinates""" | ||
latitudes = np.linspace(-90, 90, nlat, endpoint=False) | ||
longitudes = np.linspace(0, 360, nlon, endpoint=False) | ||
return latitudes, longitudes | ||
|
||
def create_regridder(self, file, compute_grid=None): | ||
"""Create regridding instances""" | ||
if compute_grid is None: | ||
compute_grid = self.compute_gaussian_grid | ||
self.input_path = self.config["input_path"] | ||
self.rotation_file = self.config["rotation_file"] | ||
self.output_path = self.config["output_path"] | ||
|
||
# open input dataset with first file | ||
ds_in = xr.open_dataset(file) | ||
nlon_i = ds_in.sizes["yh"] | ||
nlat_i = ds_in.sizes["xh"] | ||
self.ires = f"weights-{nlon_i}x{nlat_i}" | ||
self.ores = f"{self.nlon_o}x{self.nlat_o}" | ||
ds_in_t = ds_in.rename({"xh": "lon", "yh": "lat"}) | ||
ds_in_u = ds_in.rename({"xq": "lon", "yh": "lat"}) | ||
ds_in_v = ds_in.rename({"xh": "lon", "yq": "lat"}) | ||
|
||
# open rotation dataset | ||
file_rot = f"{self.rotation_file}" | ||
ds_rot = xr.open_dataset(file_rot) | ||
ds_rot = ds_rot[["cos_rot", "sin_rot"]] | ||
self.ds_rot = ds_rot.rename({"xh": "lon", "yh": "lat"}) | ||
|
||
# create output dataset | ||
self.lat1d, self.lon1d = compute_grid(self.nlat_o, self.nlon_o) | ||
lons, lats = np.meshgrid(self.lon1d, self.lat1d) | ||
da_out_lons = xr.DataArray(lons, dims=["nx", "ny"]) | ||
da_out_lats = xr.DataArray(lats, dims=["nx", "ny"]) | ||
ds_out_lons = da_out_lons.to_dataset(name="lon") | ||
ds_out_lats = da_out_lats.to_dataset(name="lat") | ||
grid_out = xr.merge([ds_out_lons, ds_out_lats]) | ||
|
||
# interpolation weights files | ||
wgtsfile_t_to_t = ( | ||
f"{self.output_path}/{self.ires}.Ct.{self.ores}.Ct.bilinear.nc" | ||
) | ||
wgtsfile_u_to_t = ( | ||
f"{self.output_path}/{self.ires}.Cu.{self.ires}.Ct.bilinear.nc" | ||
) | ||
wgtsfile_v_to_t = ( | ||
f"{self.output_path}/{self.ires}.Cv.{self.ires}.Ct.bilinear.nc" | ||
) | ||
|
||
# define regridding instances | ||
reuse = os.path.exists(wgtsfile_t_to_t) | ||
self.rg_tt = xe.Regridder( | ||
ds_in_t, | ||
grid_out, | ||
self.interp_method, | ||
periodic=True, | ||
reuse_weights=reuse, | ||
filename=wgtsfile_t_to_t, | ||
) | ||
reuse = os.path.exists(wgtsfile_u_to_t) | ||
self.rg_ut = xe.Regridder( | ||
ds_in_u, | ||
ds_in_t, | ||
self.interp_method, | ||
periodic=True, | ||
reuse_weights=reuse, | ||
filename=wgtsfile_u_to_t, | ||
) | ||
reuse = os.path.exists(wgtsfile_v_to_t) | ||
self.rg_vt = xe.Regridder( | ||
ds_in_v, | ||
ds_in_t, | ||
self.interp_method, | ||
periodic=True, | ||
reuse_weights=reuse, | ||
filename=wgtsfile_v_to_t, | ||
) | ||
|
||
def regrid(self, file): | ||
"""Regrid a single ocean file""" | ||
print(f"Regridding file: {file}") | ||
dtg = file[-16:-3] | ||
ds_in = xr.open_dataset(file) | ||
ds_out = [] | ||
|
||
for var in list(ds_in.keys()): | ||
if len(ds_in[var].coords) > 2: | ||
coords = ds_in[var].coords.to_index() | ||
|
||
# choose regridding type | ||
if coords.names[0] != "time": | ||
raise ValueError("First coordinate should be time") | ||
else: | ||
variable_map = { | ||
"SSU": ("SSV", "U"), | ||
"SSV": (None, "skip"), | ||
"uo": ("vo", "U"), | ||
"vo": (None, "skip"), | ||
"taux": ("tauy", "U"), | ||
"tauy": (None, "skip"), | ||
} | ||
if var in variable_map.keys(): | ||
var2, pos = variable_map[var] | ||
else: | ||
var2, pos = (None, "T") | ||
|
||
# 3-dimensional data | ||
is_3d = False | ||
if coords.names[1] == "z_l" or coords.names[1] == "zl": | ||
dims = ["time", "lev", "lat", "lon"] | ||
is_3d = True | ||
else: | ||
dims = ["time", "lat", "lon"] | ||
|
||
if pos == "T": | ||
interp_out = self.rg_tt(ds_in[var].values) | ||
da_out = xr.DataArray(interp_out, dims=dims) | ||
da_out.attrs["long_name"] = ds_in[var].long_name | ||
da_out.attrs["units"] = ds_in[var].units | ||
ds_out.append(da_out.to_dataset(name=var)) | ||
|
||
if pos == "U": | ||
# interplate u and v to t-point, then rotate currents/winds to | ||
# earth relative before interpolation | ||
|
||
# interpolate to t-points | ||
interp_u = self.rg_ut(ds_in[var].values) | ||
interp_v = self.rg_ut(ds_in[var2].values) | ||
|
||
# rotate to earth-relative | ||
if is_3d: | ||
urot = np.zeros(np.shape(interp_u)[1:]) | ||
vrot = np.zeros(np.shape(interp_u)[1:]) | ||
for k in range(np.shape(interp_u)[1]): | ||
urot[k] = ( | ||
interp_u[0, k, :, :] * self.ds_rot.cos_rot | ||
+ interp_v[0, k, :, :] * self.ds_rot.sin_rot | ||
) | ||
vrot[k] = ( | ||
interp_v[0, k, :, :] * self.ds_rot.cos_rot | ||
- interp_u[0, k, :, :] * self.ds_rot.sin_rot | ||
) | ||
else: | ||
urot = ( | ||
interp_u[0, :, :] * self.ds_rot.cos_rot | ||
+ interp_v[0, :, :] * self.ds_rot.sin_rot | ||
) | ||
vrot = ( | ||
interp_v[0, :, :] * self.ds_rot.cos_rot | ||
- interp_u[0, :, :] * self.ds_rot.sin_rot | ||
) | ||
|
||
# interoplate | ||
uinterp_out = self.rg_tt(urot) | ||
vinterp_out = self.rg_tt(vrot) | ||
da_out = xr.DataArray( | ||
np.expand_dims(uinterp_out, 0), | ||
dims=dims, | ||
) | ||
da_out.attrs["long_name"] = ds_in[var].long_name | ||
da_out.attrs["units"] = ds_in[var].units | ||
ds_out.append(da_out.to_dataset(name=var)) | ||
da_out = xr.DataArray( | ||
np.expand_dims(vinterp_out, 0), | ||
dims=dims, | ||
) | ||
da_out.attrs["long_name"] = ds_in[var2].long_name | ||
da_out.attrs["units"] = ds_in[var2].units | ||
ds_out.append(da_out.to_dataset(name=var2)) | ||
|
||
ds_out = xr.merge(ds_out) | ||
ds_out = ds_out.assign_coords(lon=("lon", self.lon1d)) | ||
ds_out = ds_out.assign_coords(lat=("lat", self.lat1d)) | ||
ds_out = ds_out.assign_coords(lev=("lev", ds_in.z_l.values)) | ||
ds_out = ds_out.assign_coords(time=("time", ds_in.time.values)) | ||
ds_out["lon"].attrs["units"] = "degrees_east" | ||
ds_out["lon"].attrs["axis"] = "X" | ||
ds_out["lon"].attrs["standard_name"] = "longitude" | ||
ds_out["lat"].attrs["units"] = "degrees_north" | ||
ds_out["lat"].attrs["axis"] = "Y" | ||
ds_out["lat"].attrs["standard_name"] = "latitude" | ||
ds_out["lev"].attrs["units"] = "meters" | ||
ds_out["lev"].attrs["positive"] = "down" | ||
ds_out["lev"].attrs["axis"] = "Z" | ||
ds_out.to_netcdf(f"{self.output_path}/ocn_{dtg}_{self.ores}.nc") | ||
ds_out.close() | ||
ds_in.close() | ||
del ds_out | ||
del ds_in |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: regrid | ||
channels: | ||
- conda-forge | ||
dependencies: | ||
- python=3.11 | ||
# Basics | ||
- numpy | ||
- scipy | ||
- matplotlib | ||
# xarray et al | ||
- xarray | ||
- netCDF4 | ||
- h5netcdf | ||
- bottleneck | ||
- dask[complete] | ||
- zarr | ||
- cftime | ||
# regrid | ||
- xesmf | ||
# other plotting and interaction | ||
- jupyterlab | ||
- cartopy | ||
- shapely | ||
# filesystem | ||
- fsspec | ||
- s3fs | ||
- gcsfs |