Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seam issue when doing conservative regridding with xesmf #368

Open
scoatsclim opened this issue Aug 14, 2024 · 1 comment
Open

Seam issue when doing conservative regridding with xesmf #368

scoatsclim opened this issue Aug 14, 2024 · 1 comment

Comments

@scoatsclim
Copy link

scoatsclim commented Aug 14, 2024

import dask  # required for dask.config.set below (was missing -> NameError)
import xarray as xr
import matplotlib.pyplot as plt
import intake
import cf_xarray as cfxr
import xesmf as xe
from xmip.preprocessing import combined_preprocessing

# Reproducer: conservative regridding of CNRM-CM6-1-HR `tos` onto a global
# 1x1 degree grid shows a seam in the output.

# --- Locate the datasets in the Pangeo CMIP6 catalog -----------------------
url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
col = intake.open_esm_datastore(url)
models = ["CNRM-CM6-1-HR"]
cat = col.search(
    table_id="Omon",
    grid_label="gn",
    experiment_id="historical",
    variable_id="tos",
    source_id=models,
)

# --- Load them, with xmip's preprocessing applied --------------------------
z_kwargs = {"consolidated": True, "decode_times": True}  # , 'use_cftime': True}
with dask.config.set(**{"array.slicing.split_large_chunks": True}):
    dset_dict = cat.to_dataset_dict(
        zarr_kwargs=z_kwargs,
        preprocess=combined_preprocessing,
    )

# --- Build the 1x1 degree target grid and shift it to 0..360 longitude -----
coarse_grid = xe.util.grid_global(1, 1)
# Cell centres: negative longitudes wrap by +360 (e.g. -179.5 -> 180.5).
lon = coarse_grid["lon"].where(coarse_grid["lon"] > 0, 360 + coarse_grid["lon"])
# Cell bounds: the offset is 361 (not 360) on purpose. Assuming grid_global(1, 1)
# yields bounds -180, -179, ..., 180 (TODO confirm), this maps -180..-1 to
# 181..360. After sorting, the bounds run 0..360 with no duplicate at 180 and
# stay consistent with the shifted centres.
lon_b = coarse_grid["lon_b"].where(coarse_grid["lon_b"] >= 0, 361 + coarse_grid["lon_b"])
# Sort both along the x axis, using the first row as the sort key.
lon = lon.sortby(lon[0, :])
lon_b = lon_b.sortby(lon_b[0, :])
coarse_grid = coarse_grid.assign_coords(lon_b=lon_b, lon=lon)

# --- Select the source dataset ----------------------------------------------
ds_in = dset_dict["CMIP.CNRM-CERFACS.CNRM-CM6-1-HR.historical.Omon.gn"]
ds_in = ds_in.isel(member_id=0).squeeze()
ds_in = ds_in.sel(time=slice("1856", "2014"))

# --- Prepare the source grid for xESMF --------------------------------------
# Ocean mask from a single time step (index 100 on the leading dimension,
# presumably `time` -- TODO confirm the dim order of `tos`).
in_mask = ds_in["tos"][100, :, :].notnull()
ds_in_mask = ds_in.assign(mask=in_mask)
# Convert the CMIP (cell, vertex) bounds into xESMF-style corner arrays.
# xmip renames the bounds to `lat_verticies`/`lon_verticies`; fall back to the
# raw CMIP names if they are still present.
# NOTE(review): the maintainer suspects the seam comes from the vertex order of
# these bounds. This call uses cf_xarray's default `order`; try `order=None` or
# reorder the `vertex` dimension before converting.
if "vertices_latitude" not in ds_in.variables:
    lat_corners = cfxr.bounds_to_vertices(ds_in.lat_verticies, "vertex")
    lon_corners = cfxr.bounds_to_vertices(ds_in.lon_verticies, "vertex")
else:
    lat_corners = cfxr.bounds_to_vertices(ds_in["vertices_latitude"], "vertex")
    lon_corners = cfxr.bounds_to_vertices(ds_in["vertices_longitude"], "vertex")
ds_in_mask.coords["lon_b"] = lon_corners
ds_in_mask.coords["lat_b"] = lat_corners

# --- Regrid ------------------------------------------------------------------
# NOTE(review): check the xESMF Regridder docs for whether `periodic` has any
# effect with conservative methods.
reg_mask = xe.Regridder(
    ds_in_mask,
    coarse_grid,
    "conservative",
    ignore_degenerate=True,
    periodic=True,
)
ds_out_tos = reg_mask(ds_in_mask.squeeze())

# --- Plot the first time step to show the seam ------------------------------
fig, ax = plt.subplots(ncols=1, figsize=(12, 4))
tos_first = ds_out_tos["tos"][0, :, :]
tos_first.plot(ax=ax, add_colorbar=True)

Julius told me he had a solution for this; sorry for the messy code. Thanks!

This code produces a seam in the regridded tos output.

Screenshot 2024-08-14 at 11 19 49 AM
@jbusecke
Copy link
Owner

jbusecke commented Aug 14, 2024

I am fairly sure this is due to the order of the vertices. I ran into this problem in another context.

Here is some code I used to fix this by brute force (trying every possible vertex order and testing the resulting vertex points).

import itertools
# might be missing some imports

def cmip_bounds_to_xesmf(ds: xr.Dataset, order=None):
    """Attach xESMF-style ``lon_b``/``lat_b`` corner coordinates to *ds*.

    If *ds* already has both ``lon_b`` and ``lat_b`` it is returned unchanged.
    Otherwise the xmip-named bounds ``lon_verticies``/``lat_verticies`` are
    loaded into memory and converted with ``cf_xarray.bounds_to_vertices``
    along the ``vertex`` dimension.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset carrying ``lon_verticies``/``lat_verticies`` with a ``vertex``
        dimension (unless ``lon_b``/``lat_b`` already exist).
    order : optional
        Forwarded unchanged to ``cf_xarray.bounds_to_vertices``. Per the
        original author, the right value depends on how xmip reorganized the
        vertex order; a wrong order produces stripes in the regridded output.

    Returns
    -------
    xr.Dataset
        *ds*, with ``lon_b``/``lat_b`` coordinates when they were missing.
    """
    # The surrounding code imports this package only as `cfxr`, so the bare
    # name `cf_xarray` used here was previously unresolved.
    import cf_xarray

    if all(var in ds.variables for var in ["lon_b", "lat_b"]):
        return ds
    return ds.assign_coords(
        lon_b=cf_xarray.bounds_to_vertices(
            ds.lon_verticies.load(), bounds_dim="vertex", order=order
        ),
        lat_b=cf_xarray.bounds_to_vertices(
            ds.lat_verticies.load(), bounds_dim="vertex", order=order
        ),
    )

def test_vertex_order(ds):
    """Raise ``ValueError`` unless the vertex order of *ds* gives monotonic corners.

    The check works on a small patch:

    1. Cut out ``x`` and ``y`` indices 20-21.
    2. Drop variables that depend on ``lev`` or ``time``, plus two label
       variables.
    3. Build corner coordinates with ``cmip_bounds_to_xesmf(order=None)``.

    The check passes only if ``lon_b`` strictly increases along ``x_vertices``
    AND ``lat_b`` strictly increases along ``y_vertices``.

    NOTE: the ``test_`` prefix means pytest would collect this function if it
    lived in a test module; the name is kept so existing callers still work.

    Raises
    ------
    ValueError
        If either corner coordinate is not strictly increasing on the patch.
    """
    # Intended to be a point in the southern hemisphere, to avoid strongly
    # curved cells. Whether indices 20-21 land there depends on the model
    # grid -- TODO confirm.
    p = {"x": slice(20, 22), "y": slice(20, 22)}
    ds_p = ds.isel(**p).squeeze()

    # Drop everything not needed for the corner check, so the load stays small.
    unneeded = [
        var
        for var in ds_p.variables
        if "lev" in ds_p[var].dims
        or "time" in ds_p[var].dims
        or var in ["sub_experiment_label", "variant_label"]
    ]
    ds_p = ds_p.drop_vars(unneeded)

    # Would be nice if this could reuse the `order` given to `cmip_bounds_to_xesmf`.
    ds_p = cmip_bounds_to_xesmf(ds_p, order=None)
    ds_p = ds_p.load().transpose(..., "x", "y", "vertex")

    lon_increasing = bool((ds_p.lon_b.diff("x_vertices") > 0).all())
    lat_increasing = bool((ds_p.lat_b.diff("y_vertices") > 0).all())
    # The previous `not lon_ok and lat_ok` parsed as `(not lon_ok) and lat_ok`,
    # so a non-monotonic latitude never raised.
    if not (lon_increasing and lat_increasing):
        raise ValueError("Test vertices not strictly monotinically increasing")

def reorder_vertex(ds, new_order):
    """Return *ds* with every ``vertex``-dimensioned variable permuted by *new_order*.

    Position ``k`` of the result holds the original vertex ``new_order[k]``.
    Variables without a ``vertex`` dimension are left untouched.

    The previous implementation split the dataset, concatenated single-vertex
    slices and merged the pieces back together. Concatenating along the
    (by-then missing) ``vertex`` dimension re-inserts it at axis 0, which
    changes the dimension order of every bounds variable. A list indexer on
    ``isel`` reorders in place and keeps the original dimension order.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a ``vertex`` dimension.
    new_order : sequence of int
        A permutation of the vertex indices (a tuple from
        ``itertools.permutations`` is fine).
    """
    # `list(...)`: a tuple indexer could be interpreted differently by xarray.
    return ds.isel(vertex=list(new_order))

def get_order(ds):
    """Brute-force the vertex permutation that makes ``test_vertex_order`` pass.

    Each of the 24 permutations of ``[0, 1, 2, 3]`` is tried in
    ``itertools.permutations`` order. The first one for which
    ``test_vertex_order(reorder_vertex(ds, order))`` does not raise is printed
    and returned as a tuple.

    Returns
    -------
    tuple of int or None
        The first working permutation, or ``None`` if none of them passes.
    """
    for new_order in itertools.permutations([0, 1, 2, 3]):
        ds_reordered = reorder_vertex(ds, new_order)
        try:
            test_vertex_order(ds_reordered)
        except Exception:
            # Deliberately best-effort: any failure means "try the next
            # permutation". The previous bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit, so this slow loop could not be
            # interrupted.
            continue
        print(f"{new_order=} worked!")
        return new_order
    return None

from xmip.utils import cmip6_dataset_id
import warnings
def test_and_reorder_vertex(ds):
    """Reorder the ``vertex`` dimension of *ds* so that its corners are monotonic.

    This is an expensive check. ``get_order`` tries every possible vertex
    permutation on a test point and keeps the first one that gives strictly
    monotonic lon_b/lat_b coordinates.

    * If such a permutation exists, a message is printed and *ds* is returned
      reordered accordingly.
    * Otherwise every variable with a ``vertex`` dimension is dropped, a
      message is printed, and the stripped dataset is returned. No exception
      is raised in this case.
    """
    found_order = get_order(ds)

    if found_order is not None:
        print(f"Changing vertex order for {cmip6_dataset_id(ds)}")
        return reorder_vertex(ds, found_order)

    # Fallback: discard the bounds entirely; maybe another approach works
    # better later. Raising instead was considered:
    # raise ValueError(f"Unable to find a vertex order for {cmip6_dataset_id(ds)}")
    vertex_vars = [name for name in ds.variables if "vertex" in ds[name].dims]
    stripped = ds.drop_vars(vertex_vars)
    print(f"Unable to find a vertex order for {cmip6_dataset_id(ds)}")
    return stripped

This is certainly not the most elegant solution, but I'll try to work on this sometime soon.

@jbusecke jbusecke mentioned this issue Aug 27, 2024
5 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants