Skip to content

Commit

Permalink
Add ocean data regridding to Gaussian grid functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
danielabdi-noaa committed Dec 15, 2023
1 parent 939a6b6 commit 679e33a
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 0 deletions.
2 changes: 2 additions & 0 deletions regrid/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Regrid
A convenience wrapper on top of xesmf library that regrids tripolar grid ocean data to a gaussian grid.
5 changes: 5 additions & 0 deletions regrid/config-regrid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Regrid:
input_path: .
output_path: .
rotation_file: ./ocn_rotation_mx025.nc

82 changes: 82 additions & 0 deletions regrid/gaussian_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tools for working with Gaussian grids."""
from __future__ import absolute_import, division, print_function

import functools

import numpy as np
import numpy.linalg as la
from numpy.polynomial.legendre import legcompanion, legder, legval


def __single_arg_fast_cache(func):
"""Caching decorator for functions of one argument."""

class CachingDict(dict):
def __missing__(self, key):
result = self[key] = func(key)
return result

@functools.wraps(func)
def __getitem__(self, *args, **kwargs):
return super(CachingDict, self).__getitem__(*args, **kwargs)

return CachingDict().__getitem__


@__single_arg_fast_cache
def gaussian_latitudes(n):
"""Construct latitudes and latitude bounds for a Gaussian grid.
Args:
* n:
The Gaussian grid number (half the number of latitudes in the
grid.
Returns:
A 2-tuple where the first element is a length `n` array of
latitudes (in degrees) and the second element is an `(n, 2)`
array of bounds.
"""
if abs(int(n)) != n:
raise ValueError("n must be a non-negative integer")
nlat = 2 * n
# Create the coefficients of the Legendre polynomial and construct the
# companion matrix:
cs = np.array([0] * nlat + [1], dtype=int)
cm = legcompanion(cs)
# Compute the eigenvalues of the companion matrix (the roots of the
# Legendre polynomial) taking advantage of the fact that the matrix is
# symmetric:
roots = la.eigvalsh(cm)
roots.sort()
# Improve the roots by one application of Newton's method, using the
# solved root as the initial guess:
fx = legval(roots, cs)
fpx = legval(roots, legder(cs))
roots -= fx / fpx
# The roots should exhibit symmetry, but with a sign change, so make sure
# this is the case:
roots = (roots - roots[::-1]) / 2.0
# Compute the Gaussian weights for each interval:
fm = legval(roots, cs[1:])
fm /= np.abs(fm).max()
fpx /= np.abs(fpx).max()
weights = 1.0 / (fm * fpx)
# Weights should be symmetric and sum to two (unit weighting over the
# interval [-1, 1]):
weights = (weights + weights[::-1]) / 2.0
weights *= 2.0 / weights.sum()
# Calculate the bounds from the weights, still on the interval [-1, 1]:
bounds1d = np.empty([nlat + 1])
bounds1d[0] = -1
bounds1d[1:-1] = -1 + weights[:-1].cumsum()
bounds1d[-1] = 1
# Convert the bounds to degrees of latitude on [-90, 90]:
bounds1d = np.rad2deg(np.arcsin(bounds1d))
bounds2d = np.empty([nlat, 2])
bounds2d[:, 0] = bounds1d[:-1]
bounds2d[:, 1] = bounds1d[1:]
# Convert the roots from the interval [-1, 1] to latitude values on the
# interval [-90, 90] degrees:
latitudes = np.rad2deg(np.arcsin(roots))
return latitudes, bounds2d
253 changes: 253 additions & 0 deletions regrid/regrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import glob
import sys
import os
import yaml
import xarray as xr
import numpy as np
import xesmf as xe
from gaussian_grid import gaussian_latitudes


class Regrid:
"""
Regrid ocean dataset that is on a tripolar grid to a Gaussian grid
Required fields in config:
input_path (str): path where ocean netcdf files (ocn_*.nc) are located
output_path (str): path where regridded data and interpolation weights are stored
rotation_file (str): path to file containing rotation fields "sin_rot" and "cos_rot"
Usage examples
Construct Regrid object, specifying output nlat x nlon and optionally config fiel
>>> rg = Regrid(180, 360, config_filename = "config-regrid-ocean.yaml")
Compile list of files to regrid
>>> files = glob.glob(f"./input/ocn_2016_08_02_??.nc")
>>> files.sort()
Create regridder to compute interpolation weights
>>> rg.create_regridder(files[0])
Regrid all files using weights computed above
>>> for file in files:
>>> rg.regrid(file)
"""

def __init__(
self,
nlat: int,
nlon: int,
interp_method: str = "bilinear",
config_filename: str = "config-regrid.yaml",
):
super(Regrid, self).__init__()
name = self.__class__.__name__

# read configuration from yaml file
with open(config_filename, "r") as f:
contents = yaml.safe_load(f)
self.config = contents[name]

# specify an output resolution
self.nlon_o = nlon
self.nlat_o = nlat
self.interp_method = interp_method

def compute_gaussian_grid(self, nlat, nlon):
"""Compute gaussian grid latitudes and longitudes"""
latitudes, _ = gaussian_latitudes(nlat // 2)
longitudes = np.linspace(0, 360, nlon, endpoint=False)
return latitudes, longitudes

def compute_latlon_grid(self, nlat, nlon):
"""Compute regular latlong grid coordinates"""
latitudes = np.linspace(-90, 90, nlat, endpoint=False)
longitudes = np.linspace(0, 360, nlon, endpoint=False)
return latitudes, longitudes

def create_regridder(self, file, compute_grid=None):
"""Create regridding instances"""
if compute_grid is None:
compute_grid = self.compute_gaussian_grid
self.input_path = self.config["input_path"]
self.rotation_file = self.config["rotation_file"]
self.output_path = self.config["output_path"]

# open input dataset with first file
ds_in = xr.open_dataset(file)
nlon_i = ds_in.sizes["yh"]
nlat_i = ds_in.sizes["xh"]
self.ires = f"weights-{nlon_i}x{nlat_i}"
self.ores = f"{self.nlon_o}x{self.nlat_o}"
ds_in_t = ds_in.rename({"xh": "lon", "yh": "lat"})
ds_in_u = ds_in.rename({"xq": "lon", "yh": "lat"})
ds_in_v = ds_in.rename({"xh": "lon", "yq": "lat"})

# open rotation dataset
file_rot = f"{self.rotation_file}"
ds_rot = xr.open_dataset(file_rot)
ds_rot = ds_rot[["cos_rot", "sin_rot"]]
self.ds_rot = ds_rot.rename({"xh": "lon", "yh": "lat"})

# create output dataset
self.lat1d, self.lon1d = compute_grid(self.nlat_o, self.nlon_o)
lons, lats = np.meshgrid(self.lon1d, self.lat1d)
da_out_lons = xr.DataArray(lons, dims=["nx", "ny"])
da_out_lats = xr.DataArray(lats, dims=["nx", "ny"])
ds_out_lons = da_out_lons.to_dataset(name="lon")
ds_out_lats = da_out_lats.to_dataset(name="lat")
grid_out = xr.merge([ds_out_lons, ds_out_lats])

# interpolation weights files
wgtsfile_t_to_t = (
f"{self.output_path}/{self.ires}.Ct.{self.ores}.Ct.bilinear.nc"
)
wgtsfile_u_to_t = (
f"{self.output_path}/{self.ires}.Cu.{self.ires}.Ct.bilinear.nc"
)
wgtsfile_v_to_t = (
f"{self.output_path}/{self.ires}.Cv.{self.ires}.Ct.bilinear.nc"
)

# define regridding instances
reuse = os.path.exists(wgtsfile_t_to_t)
self.rg_tt = xe.Regridder(
ds_in_t,
grid_out,
self.interp_method,
periodic=True,
reuse_weights=reuse,
filename=wgtsfile_t_to_t,
)
reuse = os.path.exists(wgtsfile_u_to_t)
self.rg_ut = xe.Regridder(
ds_in_u,
ds_in_t,
self.interp_method,
periodic=True,
reuse_weights=reuse,
filename=wgtsfile_u_to_t,
)
reuse = os.path.exists(wgtsfile_v_to_t)
self.rg_vt = xe.Regridder(
ds_in_v,
ds_in_t,
self.interp_method,
periodic=True,
reuse_weights=reuse,
filename=wgtsfile_v_to_t,
)

def regrid(self, file):
"""Regrid a single ocean file"""
print(f"Regridding file: {file}")
dtg = file[-16:-3]
ds_in = xr.open_dataset(file)
ds_out = []

for var in list(ds_in.keys()):
if len(ds_in[var].coords) > 2:
coords = ds_in[var].coords.to_index()

# choose regridding type
if coords.names[0] != "time":
raise ValueError("First coordinate should be time")
else:
variable_map = {
"SSU": ("SSV", "U"),
"SSV": (None, "skip"),
"uo": ("vo", "U"),
"vo": (None, "skip"),
"taux": ("tauy", "U"),
"tauy": (None, "skip"),
}
if var in variable_map.keys():
var2, pos = variable_map[var]
else:
var2, pos = (None, "T")

# 3-dimensional data
is_3d = False
if coords.names[1] == "z_l" or coords.names[1] == "zl":
dims = ["time", "lev", "lat", "lon"]
is_3d = True
else:
dims = ["time", "lat", "lon"]

if pos == "T":
interp_out = self.rg_tt(ds_in[var].values)
da_out = xr.DataArray(interp_out, dims=dims)
da_out.attrs["long_name"] = ds_in[var].long_name
da_out.attrs["units"] = ds_in[var].units
ds_out.append(da_out.to_dataset(name=var))

if pos == "U":
# interplate u and v to t-point, then rotate currents/winds to
# earth relative before interpolation

# interpolate to t-points
interp_u = self.rg_ut(ds_in[var].values)
interp_v = self.rg_ut(ds_in[var2].values)

# rotate to earth-relative
if is_3d:
urot = np.zeros(np.shape(interp_u)[1:])
vrot = np.zeros(np.shape(interp_u)[1:])
for k in range(np.shape(interp_u)[1]):
urot[k] = (
interp_u[0, k, :, :] * self.ds_rot.cos_rot
+ interp_v[0, k, :, :] * self.ds_rot.sin_rot
)
vrot[k] = (
interp_v[0, k, :, :] * self.ds_rot.cos_rot
- interp_u[0, k, :, :] * self.ds_rot.sin_rot
)
else:
urot = (
interp_u[0, :, :] * self.ds_rot.cos_rot
+ interp_v[0, :, :] * self.ds_rot.sin_rot
)
vrot = (
interp_v[0, :, :] * self.ds_rot.cos_rot
- interp_u[0, :, :] * self.ds_rot.sin_rot
)

# interoplate
uinterp_out = self.rg_tt(urot)
vinterp_out = self.rg_tt(vrot)
da_out = xr.DataArray(
np.expand_dims(uinterp_out, 0),
dims=dims,
)
da_out.attrs["long_name"] = ds_in[var].long_name
da_out.attrs["units"] = ds_in[var].units
ds_out.append(da_out.to_dataset(name=var))
da_out = xr.DataArray(
np.expand_dims(vinterp_out, 0),
dims=dims,
)
da_out.attrs["long_name"] = ds_in[var2].long_name
da_out.attrs["units"] = ds_in[var2].units
ds_out.append(da_out.to_dataset(name=var2))

ds_out = xr.merge(ds_out)
ds_out = ds_out.assign_coords(lon=("lon", self.lon1d))
ds_out = ds_out.assign_coords(lat=("lat", self.lat1d))
ds_out = ds_out.assign_coords(lev=("lev", ds_in.z_l.values))
ds_out = ds_out.assign_coords(time=("time", ds_in.time.values))
ds_out["lon"].attrs["units"] = "degrees_east"
ds_out["lon"].attrs["axis"] = "X"
ds_out["lon"].attrs["standard_name"] = "longitude"
ds_out["lat"].attrs["units"] = "degrees_north"
ds_out["lat"].attrs["axis"] = "Y"
ds_out["lat"].attrs["standard_name"] = "latitude"
ds_out["lev"].attrs["units"] = "meters"
ds_out["lev"].attrs["positive"] = "down"
ds_out["lev"].attrs["axis"] = "Z"
ds_out.to_netcdf(f"{self.output_path}/ocn_{dtg}_{self.ores}.nc")
ds_out.close()
ds_in.close()
del ds_out
del ds_in
27 changes: 27 additions & 0 deletions regrid/regrid_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: regrid
channels:
- conda-forge
dependencies:
- python=3.11
# Basics
- numpy
- scipy
- matplotlib
# xarray et al
- xarray
- netCDF4
- h5netcdf
- bottleneck
- dask[complete]
- zarr
- cftime
# regrid
- xesmf
# other plotting and interaction
- jupyterlab
- cartopy
- shapely
# filesystem
- fsspec
- s3fs
- gcsfs

0 comments on commit 679e33a

Please sign in to comment.