Skip to content

Commit

Permalink
improved behavior of subsample.mooring (#364)
Browse files Browse the repository at this point in the history
* refactor a bit

* format

* isort

* allow NoneType

* fix test

* black format

* restore

* fix drop var when XRange is None (but YRange is not None)

* def rel_lon when XRange is None

* black formatting

* improve comp

* format - black

* black

* black format

---------

Co-authored-by: Miguel Jimenez <mjimen17@jhu.edu>
  • Loading branch information
Miguel Jimenez and Mikejmnez authored May 26, 2023
1 parent cff27a3 commit a043890
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 75 deletions.
52 changes: 24 additions & 28 deletions oceanspy/llc_rearrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dask
import numpy as _np
import xarray as _xr
from xarray import DataArray, Dataset
from xgcm import Grid

from .utils import _rel_lon, _reset_range, get_maskH
Expand All @@ -32,10 +33,6 @@
]


_datype = _xr.core.dataarray.DataArray
_dstype = _xr.core.dataset.Dataset


class LLCtransformation:
"""A class containing the transformation types of LLCgrids."""

Expand Down Expand Up @@ -236,13 +233,13 @@ def arctic_crown(
ARCT[i] = _xr.merge(ARCT[i])

DSa2, DSa5, DSa7, DSa10 = ARCT
if type(DSa2) != _dstype:
if type(DSa2) != Dataset:
DSa2 = 0
if type(DSa5) != _dstype:
if type(DSa5) != Dataset:
DSa5 = 0
if type(DSa7) != _dstype:
if type(DSa7) != Dataset:
DSa7 = 0
if type(DSa10) != _dstype:
if type(DSa10) != Dataset:
DSa10 = 0

DSa7 = shift_dataset(DSa7, dims_c.X, dims_g.X)
Expand Down Expand Up @@ -311,11 +308,11 @@ def arctic_crown(

# Here, address shifts in Arctic
# arctic exchange with face 10
if type(faces2[0]) == _dstype:
if type(faces2[0]) == Dataset:
faces2[0]["Yp1"] = faces2[0]["Yp1"] + 1

# Arctic exchange with face 2
if type(faces3[3]) == _dstype:
if type(faces3[3]) == Dataset:
faces3[3]["Xp1"] = faces3[3]["Xp1"] + 1

# =====
Expand Down Expand Up @@ -383,8 +380,8 @@ def arctic_crown(
# First, check if there is data in both DSFacet12 and DSFacet34.
# If not, then there is no need to transpose data in DSFacet12.

if type(DSFacet12) == _dstype:
if type(DSFacet34) == _dstype:
if type(DSFacet12) == Dataset:
if type(DSFacet34) == Dataset:
# two lines below asserts correct
# staggering of center and corner points
# in latitude (otherwise, lat has a jump)
Expand Down Expand Up @@ -429,8 +426,7 @@ def arctic_crown(
if chunks:
DS = DS.chunk(chunks)

if XRange is not None and YRange is not None:
# drop copy var = 'nYg' (line 101)
if "nYG" in DS.reset_coords().data_vars:
DS = DS.drop_vars(_var_)

if geo_true:
Expand Down Expand Up @@ -688,7 +684,7 @@ def rotate_vars(_ds):
topology makes it so that u on a rotated face transforms to `+- v` on a lat lon
grid.
"""
if type(_ds) == _dstype: # if a dataset transform otherwise pass
if type(_ds) == Dataset: # if a dataset transform otherwise pass
_ds = _copy.deepcopy(_ds)
_vars = [var for var in _ds.variables]
rot_names = {}
Expand Down Expand Up @@ -716,7 +712,7 @@ def shift_dataset(_ds, dims_c, dims_g):
dims_c.
"""
if type(_ds) == _dstype: # if a dataset transform otherwise pass
if type(_ds) == Dataset: # if a dataset transform otherwise pass
_ds = _copy.deepcopy(_ds)
for _dim in [dims_c, dims_g]:
if int(_ds[_dim][0].data) < int(_ds[_dim][1].data):
Expand All @@ -737,7 +733,7 @@ def reverse_dataset(_ds, dims_c, dims_g, transpose=False):
so dims_c is either one of `i` or `j`, and dims_g is either one of `i_g` or `j_g`.
The pair most correspond to the same dimension."""

if type(_ds) == _dstype: # if a dataset transform otherwise pass
if type(_ds) == Dataset: # if a dataset transform otherwise pass
_ds = _copy.deepcopy(_ds)

for _dim in [dims_c, dims_g]: # This part should be different for j_g points?
Expand Down Expand Up @@ -771,7 +767,7 @@ def rotate_dataset(
nface=int: correct number to use. This is the case a merger/concatenated dataset is
being manipulated. Nij is no longer the size of the face.
"""
if type(_ds) == _dstype: # if a dataset transform otherwise pass
if type(_ds) == Dataset: # if a dataset transform otherwise pass
_ds = _copy.deepcopy(_ds)
Nij = max(len(_ds[dims_c.X]), len(_ds[dims_c.Y]))

Expand Down Expand Up @@ -832,7 +828,7 @@ def shift_list_ds(_DS, dims_c, dims_g, Ni, facet=1):
else:
for _dim in [dims_c, dims_g]:
dim0 = int(_DS[ii - 1][_dim][-1].data + 1)
if type(_DS[ii]) == _dstype:
if type(_DS[ii]) == Dataset:
for _dim in [dims_c, dims_g]:
_DS[ii]["n" + _dim] = (
_DS[ii][_dim] - (fac * int(_DS[ii][_dim][0].data)) + dim0
Expand All @@ -845,7 +841,7 @@ def shift_list_ds(_DS, dims_c, dims_g, Ni, facet=1):
)
DS = []
for lll in range(len(_DS)):
if type(_DS[lll]) == _dstype:
if type(_DS[lll]) == Dataset:
DS.append(_DS[lll])
else:
DS = _DS
Expand Down Expand Up @@ -886,7 +882,7 @@ def flip_v(_ds, co_list=metrics, dims=True, _len=3):
dims is given
"""
if type(_ds) == _dstype:
if type(_ds) == Dataset:
for _varName in _ds.variables:
if dims:
DIMS = [dim for dim in _ds[_varName].dims if dim != "face"]
Expand Down Expand Up @@ -993,12 +989,12 @@ def arc_limits_mask(_ds, _var, _faces, _dims, XRange, YRange):
ARCT[3].append(DS[3])

for i in range(len(ARCT)): # Not all faces survive the cutout
if type(ARCT[i][0]) == _datype:
if type(ARCT[i][0]) == DataArray:
ARCT[i] = _xr.merge(ARCT[i])

DSa2, DSa5, DSa7, DSa10 = ARCT

if type(DSa2) != _dstype:
if type(DSa2) != Dataset:
DSa2 = 0
[Xi_2, Xf_2] = [0, 0]
else:
Expand All @@ -1007,7 +1003,7 @@ def arc_limits_mask(_ds, _var, _faces, _dims, XRange, YRange):
else:
Xf_2 = _edge_arc_data(DSa2[_var], 2, _dims)
Xi_2 = int(DSa2[_var][_dims.X][0])
if type(DSa5) != _dstype:
if type(DSa5) != Dataset:
DSa5 = 0
[Yi_5, Yf_5] = [0, 0]
else:
Expand All @@ -1016,7 +1012,7 @@ def arc_limits_mask(_ds, _var, _faces, _dims, XRange, YRange):
else:
Yf_5 = _edge_arc_data(DSa5[_var], 5, _dims)
Yi_5 = int(DSa5[_var][_dims.Y][0])
if type(DSa7) != _dstype:
if type(DSa7) != Dataset:
DSa7 = 0
[Xi_7, Xf_7] = [0, 0]
else:
Expand All @@ -1026,7 +1022,7 @@ def arc_limits_mask(_ds, _var, _faces, _dims, XRange, YRange):
Xi_7 = _edge_arc_data(DSa7[_var], 7, _dims)
Xf_7 = int(DSa7[_var][_dims.X][-1])

if type(DSa10) != _dstype:
if type(DSa10) != Dataset:
DSa10 = 0
[Yi_10, Yf_10] = [0, 0]
else:
Expand Down Expand Up @@ -1056,7 +1052,7 @@ def _edge_facet_data(_Facet_list, _var, _dims, _axis):

XRange = []
for i in range(len(_Facet_list)):
if type(_Facet_list[i]) == _dstype:
if type(_Facet_list[i]) == Dataset:
# there is data
_da = _Facet_list[i][_var].load() # load into memory 2d data.
X0 = []
Expand Down Expand Up @@ -1091,7 +1087,7 @@ def slice_datasets(_DSfacet, dims_c, dims_g, _edges, _axis):
_DSFacet = _copy.deepcopy(_DSfacet)
for i in range(len(_DSFacet)):
# print(i)
if type(_DSFacet[i]) == _dstype:
if type(_DSFacet[i]) == Dataset:
for _dim in [_dim_c, _dim_g]:
if len(_edges) == 1:
ii_0 = int(_edges[0])
Expand Down
9 changes: 5 additions & 4 deletions oceanspy/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,7 @@ def cutout(
# Drop variables
if varList is not None:
# Make sure it's a list
varList = list(varList)
varList = varList + co_list
varList = _rename_aliased(od, varList)
varList = _rename_aliased(od, list(varList) + co_list)

# Compute missing variables
od = _compute._add_missing_variables(od, varList)
Expand Down Expand Up @@ -402,7 +400,10 @@ def cutout(
# ---------------------------
# Initialize horizontal mask
if XRange is not None or YRange is not None:
XRange, ref_lon = _reset_range(XRange)
if XRange is not None:
XRange, ref_lon = _reset_range(XRange)
else:
ref_lon = 180
maskH, dmaskH, XRange, YRange = get_maskH(
ds, add_Hbdr, XRange, YRange, ref_lon=ref_lon
)
Expand Down
2 changes: 1 addition & 1 deletion oceanspy/tests/test_llc_rearrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,7 +2592,7 @@ def test_mask_var(od, XRange, YRange):
(od, P06_lon, P06_lat, [0, 0], [0, 0], [0, 0], [0, 0]),
(od, [-31, -2], [58, 68.2], [0, 3], [0, 0], [0, 0], [0, 0]),
(od, [160, -160], [58, 85.2], [0, 0], [0, 0], [52, 89], [0, 0]),
(od, [160, 100], [58, 85.2], [0, 0], [0, 39], [51, 89], [0, 0]),
(od, [160, 100], [58, 85.2], [0, 39], [0, 39], [51, 89], [51, 89]),
],
)
def test_arc_limits_mask(od, XRange, YRange, A, B, C, D):
Expand Down
20 changes: 15 additions & 5 deletions oceanspy/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def test_circle_path_array(lats, lons, symmetry, resolution):
coords1 = [[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]]
coords2 = [[[5, 0], [4, 1], [3, 2], [2, 3], [1, 4], [0, 5]]]
coords3 = [[[0, 6], [0, 7], [0, 8], [0, 9], [0, 10], [0, 11]]]
lons = []
coords4 = '[{"type":"Point","coordinates":[-169.23960833202577,22.865677261831266]}]'
coords5 = '[{"type":"Point","coordinates":[636.7225446274502, -56.11128546740994]}]'
coords6 = '[{"type":"Point","coordinates":[754.2277421326479, -57.34299561290217]}]'
coords7 = '[{"type":"Point","coordinates":[-424.42989807993234, 37.87263032287052]}]'


@pytest.mark.parametrize(
Expand All @@ -83,6 +87,9 @@ def test_circle_path_array(lats, lons, symmetry, resolution):
(coords2, "LineString", [[5, 0]], [[4, 1]]),
(coords3, "LineString", [[0, 6]], [[0, 7]]),
(coords4, "Point", [-169.23960833202577], [22.865677261831266]),
(coords5, "Point", [-83.27745537254975], [-56.11128546740994]),
(coords6, "Point", [34.227742132647904], [-57.34299561290217]),
(coords7, "Point", [-64.42989807993234], [37.87263032287052]),
],
)
def test_viewer_to_range(coords, types, lon, lat):
Expand Down Expand Up @@ -116,14 +123,17 @@ def test_viewer_to_range(coords, types, lon, lat):
(X2, X0, 53.67),
(X3, X3, 180),
(X4, X3, 180),
(X5, _np.array([161, 19]), 113.67),
(X6, _np.array([161, 19]), 113.67),
(X7, X7, 180),
(X5, None, 180),
(X6, None, 180),
(X7, X7, 6.67),
],
)
def test_reset_range(XRange, x0, expected_ref):
"""test the function rel_lon which redefines the reference long."""
x_range, ref_lon = _reset_range(XRange)
assert len(x_range) == 2
assert x_range.all() == x0.all()
if x0 is not None:
assert len(x_range) == 2
assert x_range.all() == x0.all()
else:
assert x_range is None
assert _np.round(ref_lon, 2) == expected_ref
89 changes: 52 additions & 37 deletions oceanspy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ def viewer_to_range(p):
lon.append(coords[i][0])
lat.append(coords[i][1])

return lon, lat
# check that there are no lon values greater than 180 (abs)
ll = _np.where(abs(_np.array(lon)) > 180)[0]
if ll.size:
lon = _np.array(lon)
sign = _np.sign(lon[ll])
fac = _np.round(abs(lon)[ll] / 360)
nlon = lon[ll] - 360 * sign * fac
lon[ll] = nlon

return list(lon), lat


def _rel_lon(x, ref_lon):
Expand All @@ -108,7 +117,7 @@ def _rel_lon(x, ref_lon):
return (x - ref_lon) % 360


def _reset_range(x):
def _reset_range(xn):
"""Resets the definition of XRange, by default the discontinuity at 180 long.
Checks that there is no sign change in x and if there is, the only change that
is allowed is when crossing zero. Otherwise resets ref_lon.
Expand All @@ -127,41 +136,47 @@ def _reset_range(x):
redefined_x: numpy.array.
converted longitude.
"""

ref_lon = 180
if x is not None:
if (_np.sign(x) == _np.sign(x[0])).all(): # no sign change
_ref_lon = ref_lon
X0, X1 = _np.min(x), _np.max(x)
else: # change in sign
if len(x) == 2: # list of end points
X0, X1 = x
if x[0] > x[1]: # across discontinuity
if abs(x[1] - x[0]) > 300: # across a discontinuity (Delta X =360)
_ref_lon = x[0] - (x[0] - x[1]) / 3
else: # XRange decreases, but not necessarity a dicont.
_ref_lon = ref_lon
else:
_ref_lon = ref_lon
else: # array of values.
_del = abs(x[1:] - x[:-1]) # only works with one crossing
if len(_np.where(abs(_del) > 300)[0]) > 0: # there's discontinuity
ll = _np.where(_del == max(_del))[0][0]
if x[ll] > x[ll + 1]: # track starts west of jump
X0 = _np.min(x[: ll + 1])
X1 = _np.max(x[ll + 1 :])
else:
X0 = _np.min(x[ll + 1 :])
X1 = _np.max(x[: ll + 1])
_ref_lon = X0 - (X0 - X1) / 3
else: # no discontinuity
X0 = _np.min(x)
X1 = _np.max(x)
_ref_lon = ref_lon
x = _np.array([X0, X1])
else:
_ref_lon = ref_lon
return x, _np.round(_ref_lon, 2)
_ref_lon = 180
xn = _np.array(xn)
cross = _np.where(_np.diff(_np.sign(xn)))[0]
if cross.size and xn.size != 2:
ref = 180, 0
if cross.size == 1: # one sign change
d1 = [abs(abs(xn[cross[0]]) - i) for i in ref]
i0 = _np.argwhere(_np.array(d1) == min(d1))[0][0]
if i0 == 0: # Pacific
ll = _np.where(xn > 0)[0]
nxn = _copy.deepcopy(xn)
nxn[ll] = nxn[ll] - 360
X = _np.min(nxn) + 360, _np.max(nxn)
_ref_lon = X[0] - (X[0] - X[1]) / 3
else: # Atlantic
X = _np.min(xn), _np.max(xn)
if cross.size > 1: # 2 or more sign changes
da = [abs(abs(xn[i]) - 180) for i in cross]
db = [abs(abs(xn[i]) - 0) for i in cross]
d = _np.array([[da[i], db[i]] for i in range(len(da))])
ind = [_np.argwhere(d[i] == min(d[i]))[0][0] for i in range(len(d))]
if all(ind[0] == i for i in ind) and ind[0] == 0: # Pacific
ll = _np.where(xn > 0)[0]
nxn = _copy.deepcopy(xn)
nxn[ll] = nxn[ll] - 360
X = _np.min(nxn) + 360, _np.max(nxn)
_ref_lon = X[0] - (X[0] - X[1]) / 3
elif all(ind[0] == i for i in ind) and ind[0] == 1: # Atlantic
X = _np.min(xn), _np.max(xn)
else:
X = None
elif cross.size == 0 or xn.size == 2:
if xn.size == 2:
X = xn[0], xn[1]
if xn[0] > xn[1]:
_ref_lon = X[0] - (X[0] - X[1]) / 3
else:
X = _np.min(xn), _np.max(xn)
if X is not None:
X = _np.array(X)
return X, _np.round(_ref_lon, 2)


def spherical2cartesian(Y, X, R=None):
Expand Down

0 comments on commit a043890

Please sign in to comment.