Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring GridCode to be called GridType #1615

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DeferredNetcdfFileBuffer,
NetcdfFileBuffer,
)
from .grid import CGrid, Grid, GridCode
from .grid import CGrid, Grid, GridType

__all__ = ['Field', 'VectorField', 'NestedField']

Expand Down Expand Up @@ -178,7 +178,7 @@ def __init__(self, name, data, lon=None, lat=None, depth=None, time=None, grid=N
self.interp_method = interp_method
self.gridindexingtype = gridindexingtype
if self.interp_method in ['bgrid_velocity', 'bgrid_w_velocity', 'bgrid_tracer'] and \
self.grid.gtype in [GridCode.RectilinearSGrid, GridCode.CurvilinearSGrid]:
self.grid.gtype in [GridType.RectilinearSGrid, GridType.CurvilinearSGrid]:
logger.warning_once('General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.')

self.fieldset = None
Expand Down Expand Up @@ -687,7 +687,7 @@ def calc_cell_edge_sizes(self):
Currently only works for Rectilinear Grids
"""
if not self.grid.cell_edge_sizes:
if self.grid.gtype in (GridCode.RectilinearZGrid, GridCode.RectilinearSGrid):
if self.grid.gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid):
self.grid.cell_edge_sizes['x'] = np.zeros((self.grid.ydim, self.grid.xdim), dtype=np.float32)
self.grid.cell_edge_sizes['y'] = np.zeros((self.grid.ydim, self.grid.xdim), dtype=np.float32)

Expand Down Expand Up @@ -877,15 +877,15 @@ def search_indices_rectilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea
yi, eta = -1, 0

if grid.zdim > 1 and not search2D:
if grid.gtype == GridCode.RectilinearZGrid:
if grid.gtype == GridType.RectilinearZGrid:
# Never passes here, because in this case, we work with scipy
try:
(zi, zeta) = self.search_indices_vertical_z(z)
except FieldOutOfBoundError:
raise FieldOutOfBoundError(x, y, z, field=self)
except FieldOutOfBoundSurfaceError:
raise FieldOutOfBoundSurfaceError(x, y, z, field=self)
elif grid.gtype == GridCode.RectilinearSGrid:
elif grid.gtype == GridType.RectilinearSGrid:
(zi, zeta) = self.search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
else:
zi, zeta = -1, 0
Expand Down Expand Up @@ -973,12 +973,12 @@ def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea
eta = min(1., eta)

if grid.zdim > 1 and not search2D:
if grid.gtype == GridCode.CurvilinearZGrid:
if grid.gtype == GridType.CurvilinearZGrid:
try:
(zi, zeta) = self.search_indices_vertical_z(z)
except FieldOutOfBoundError:
raise FieldOutOfBoundError(x, y, z, field=self)
elif grid.gtype == GridCode.CurvilinearSGrid:
elif grid.gtype == GridType.CurvilinearSGrid:
(zi, zeta) = self.search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
else:
zi = -1
Expand All @@ -995,7 +995,7 @@ def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea
return (xsi, eta, zeta, xi, yi, zi)

def search_indices(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False):
if self.grid.gtype in [GridCode.RectilinearSGrid, GridCode.RectilinearZGrid]:
if self.grid.gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
return self.search_indices_rectilinear(x, y, z, ti, time, particle=particle, search2D=search2D)
else:
return self.search_indices_curvilinear(x, y, z, ti, time, particle=particle, search2D=search2D)
Expand Down Expand Up @@ -1399,12 +1399,12 @@ def write(self, filename, varname=None):
vname_depth = 'depth%s' % self.name.lower()

# Create DataArray objects for file I/O
if self.grid.gtype == GridCode.RectilinearZGrid:
if self.grid.gtype == GridType.RectilinearZGrid:
nav_lon = xr.DataArray(self.grid.lon + np.zeros((self.grid.ydim, self.grid.xdim), dtype=np.float32),
coords=[('y', self.grid.lat), ('x', self.grid.lon)])
nav_lat = xr.DataArray(self.grid.lat.reshape(self.grid.ydim, 1) + np.zeros(self.grid.xdim, dtype=np.float32),
coords=[('y', self.grid.lat), ('x', self.grid.lon)])
elif self.grid.gtype == GridCode.CurvilinearZGrid:
elif self.grid.gtype == GridType.CurvilinearZGrid:
nav_lon = xr.DataArray(self.grid.lon, coords=[('y', range(self.grid.ydim)),
('x', range(self.grid.xdim))])
nav_lat = xr.DataArray(self.grid.lat, coords=[('y', range(self.grid.ydim)),
Expand Down Expand Up @@ -1553,7 +1553,7 @@ def spatial_c_grid_interpolation2D(self, ti, z, y, x, time, particle=None, apply
grid = self.U.grid
(xsi, eta, zeta, xi, yi, zi) = self.U.search_indices(x, y, z, ti, time, particle=particle)

if grid.gtype in [GridCode.RectilinearSGrid, GridCode.RectilinearZGrid]:
if grid.gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
px = np.array([grid.lon[xi], grid.lon[xi+1], grid.lon[xi+1], grid.lon[xi]])
py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi+1], grid.lat[yi+1]])
else:
Expand Down Expand Up @@ -1621,7 +1621,7 @@ def spatial_c_grid_interpolation3D_full(self, ti, z, y, x, time, particle=None):
grid = self.U.grid
(xsi, eta, zet, xi, yi, zi) = self.U.search_indices(x, y, z, ti, time, particle=particle)

if grid.gtype in [GridCode.RectilinearSGrid, GridCode.RectilinearZGrid]:
if grid.gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
px = np.array([grid.lon[xi], grid.lon[xi+1], grid.lon[xi+1], grid.lon[xi]])
py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi+1], grid.lat[yi+1]])
else:
Expand Down Expand Up @@ -1737,7 +1737,7 @@ def spatial_c_grid_interpolation3D(self, ti, z, y, x, time, particle=None, apply
interpolating linearly V depending on the latitude coordinate.
Curvilinear grids are treated properly, since the element is projected to a rectilinear parent element.
"""
if self.U.grid.gtype in [GridCode.RectilinearSGrid, GridCode.CurvilinearSGrid]:
if self.U.grid.gtype in [GridType.RectilinearSGrid, GridType.CurvilinearSGrid]:
(u, v, w) = self.spatial_c_grid_interpolation3D_full(ti, z, y, x, time, particle=particle)
else:
(u, v) = self.spatial_c_grid_interpolation2D(ti, z, y, x, time, particle=particle)
Expand Down
12 changes: 6 additions & 6 deletions parcels/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from parcels.tools.converters import TimeConverter
from parcels.tools.loggers import logger

__all__ = ['GridCode', 'RectilinearZGrid', 'RectilinearSGrid', 'CurvilinearZGrid', 'CurvilinearSGrid', 'CGrid', 'Grid']
__all__ = ['GridType', 'RectilinearZGrid', 'RectilinearSGrid', 'CurvilinearZGrid', 'CurvilinearSGrid', 'CGrid', 'Grid']
erikvansebille marked this conversation as resolved.
Show resolved Hide resolved


class GridCode(IntEnum):
class GridType(IntEnum):
RectilinearZGrid = 0
RectilinearSGrid = 1
CurvilinearZGrid = 2
Expand Down Expand Up @@ -326,7 +326,7 @@ def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh='flat
if isinstance(depth, np.ndarray):
assert (len(depth.shape) <= 1), 'depth is not a vector'

self.gtype = GridCode.RectilinearZGrid
self.gtype = GridType.RectilinearZGrid
self.depth = np.zeros(1, dtype=np.float32) if depth is None else depth
if not self.depth.flags['C_CONTIGUOUS']:
self.depth = np.array(self.depth, order='C')
Expand Down Expand Up @@ -371,7 +371,7 @@ def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh='flat'):
super().__init__(lon, lat, time, time_origin, mesh)
assert (isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4]), 'depth is not a 3D or 4D numpy array'

self.gtype = GridCode.RectilinearSGrid
self.gtype = GridType.RectilinearSGrid
self.depth = depth
if not self.depth.flags['C_CONTIGUOUS']:
self.depth = np.array(self.depth, order='C')
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh='flat
if isinstance(depth, np.ndarray):
assert (len(depth.shape) == 1), 'depth is not a vector'

self.gtype = GridCode.CurvilinearZGrid
self.gtype = GridType.CurvilinearZGrid
self.depth = np.zeros(1, dtype=np.float32) if depth is None else depth
if not self.depth.flags['C_CONTIGUOUS']:
self.depth = np.array(self.depth, order='C')
Expand Down Expand Up @@ -525,7 +525,7 @@ def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh='flat'):
super().__init__(lon, lat, time, time_origin, mesh)
assert (isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4]), 'depth is not a 4D numpy array'

self.gtype = GridCode.CurvilinearSGrid
self.gtype = GridType.CurvilinearSGrid
self.depth = depth # should be a C-contiguous array of floats
if not self.depth.flags['C_CONTIGUOUS']:
self.depth = np.array(self.depth, order='C')
Expand Down
18 changes: 9 additions & 9 deletions parcels/include/index_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ typedef enum
typedef enum
{
RECTILINEAR_Z_GRID=0, RECTILINEAR_S_GRID=1, CURVILINEAR_Z_GRID=2, CURVILINEAR_S_GRID=3
} GridCode;
} GridType;

// equal/closeness comparison that is equal to numpy (double)
static inline bool is_close_dbl(double a, double b) {
Expand Down Expand Up @@ -206,7 +206,7 @@ static inline void reconnect_bnd_indices(int *xi, int *yi, int xdim, int ydim, i
}


static inline StatusCode search_indices_rectilinear(type_coord x, type_coord y, type_coord z, CStructuredGrid *grid, GridCode gcode,
static inline StatusCode search_indices_rectilinear(type_coord x, type_coord y, type_coord z, CStructuredGrid *grid, GridType gtype,
int *xi, int *yi, int *zi, double *xsi, double *eta, double *zeta,
int ti, double time, double t0, double t1, int interp_method,
int gridindexingtype)
Expand Down Expand Up @@ -284,7 +284,7 @@ static inline StatusCode search_indices_rectilinear(type_coord x, type_coord y,

StatusCode status;
if (zdim > 1){
switch(gcode){
switch(gtype){
case RECTILINEAR_Z_GRID:
status = search_indices_vertical_z(z, zdim, zvals, zi, zeta, gridindexingtype);
break;
Expand Down Expand Up @@ -316,7 +316,7 @@ static inline StatusCode search_indices_rectilinear(type_coord x, type_coord y,
}


static inline StatusCode search_indices_curvilinear(type_coord x, type_coord y, type_coord z, CStructuredGrid *grid, GridCode gcode,
static inline StatusCode search_indices_curvilinear(type_coord x, type_coord y, type_coord z, CStructuredGrid *grid, GridType gtype,
int *xi, int *yi, int *zi, double *xsi, double *eta, double *zeta,
int ti, double time, double t0, double t1, int interp_method,
int gridindexingtype)
Expand Down Expand Up @@ -428,7 +428,7 @@ static inline StatusCode search_indices_curvilinear(type_coord x, type_coord y,

StatusCode status;
if (zdim > 1){
switch(gcode){
switch(gtype){
case CURVILINEAR_Z_GRID:
status = search_indices_vertical_z(z, zdim, zvals, zi, zeta, gridindexingtype);
break;
Expand Down Expand Up @@ -457,18 +457,18 @@ static inline StatusCode search_indices_curvilinear(type_coord x, type_coord y,
* */
static inline StatusCode search_indices(type_coord x, type_coord y, type_coord z, CStructuredGrid *grid,
int *xi, int *yi, int *zi, double *xsi, double *eta, double *zeta,
GridCode gcode, int ti, double time, double t0, double t1, int interp_method,
GridType gtype, int ti, double time, double t0, double t1, int interp_method,
int gridindexingtype)
{
switch(gcode){
switch(gtype){
case RECTILINEAR_Z_GRID:
case RECTILINEAR_S_GRID:
return search_indices_rectilinear(x, y, z, grid, gcode, xi, yi, zi, xsi, eta, zeta,
return search_indices_rectilinear(x, y, z, grid, gtype, xi, yi, zi, xsi, eta, zeta,
ti, time, t0, t1, interp_method, gridindexingtype);
break;
case CURVILINEAR_Z_GRID:
case CURVILINEAR_S_GRID:
return search_indices_curvilinear(x, y, z, grid, gcode, xi, yi, zi, xsi, eta, zeta,
return search_indices_curvilinear(x, y, z, grid, gtype, xi, yi, zi, xsi, eta, zeta,
ti, time, t0, t1, interp_method, gridindexingtype);
break;
default:
Expand Down
Loading
Loading