From 70de5ae6f7a13d1c4b38992d69bd02f27c2b51b2 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 25 Aug 2025 15:28:57 +0200 Subject: [PATCH 1/6] Adding field._load_timesteps method --- parcels/field.py | 33 +++++++++++++++++++++++---------- parcels/fieldset.py | 7 +++++++ parcels/particleset.py | 7 +++++-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 7d14eff60a..84abc6b068 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -127,7 +127,7 @@ def __init__( data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) self.name = name - self.data = data + self.data_full = data self.grid = grid try: @@ -162,8 +162,8 @@ def __init__( elif self.grid._mesh == "spherical": self.units = unitconverters_map[self.name] - if self.data.shape[0] > 1: - if "time" not in self.data.coords: + if data.shape[0] > 1: + if "time" not in data.coords: raise ValueError("Field data is missing a 'time' coordinate.") @property @@ -178,27 +178,27 @@ def units(self, value): @property def xdim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.xdim else: raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.ydim else: raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.zdim else: - if "nz1" in self.data.dims: - return self.data.sizes["nz1"] - elif "nz" in self.data.dims: - return self.data.sizes["nz"] + if "nz1" in self.data_full.dims: + return self.data_full.sizes["nz1"] + elif "nz" in self.data_full.dims: + return self.data_full.sizes["nz"] else: return 0 @@ -219,6 +219,19 @@ def _check_velocitysampling(self): stacklevel=2, ) + def _load_timesteps(self, time): + """Load the appropriate timesteps of a field.""" + ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 + if not hasattr(self, "data"): + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + elif self.data_full.time.data[ti] == self.data.time.data[1]: + self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") + elif self.data_full.time.data[ti] != self.data.time.data[0]: + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + assert len(self.data.time) == 2, ( + f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." + ) + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 113ef637da..9b32c95d46 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -82,6 +82,13 @@ def time_interval(self): return None return functools.reduce(lambda x, y: x.intersection(y), time_intervals) + def _load_timesteps(self, time): + """Load the appropriate timesteps of all fields in the fieldset.""" + for fldname in self.fields: + field = self.fields[fldname] + if isinstance(field, Field): + field._load_timesteps(time) + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. diff --git a/parcels/particleset.py b/parcels/particleset.py index 828f85c5ed..5cbe240dca 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -578,10 +578,13 @@ def execute( time = start_time while sign_dt * (time - end_time) < 0: + # Load the appropriate timesteps of the fieldset + self.fieldset._load_timesteps(self._data["time_nextloop"][0]) + if sign_dt > 0: - next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works + next_time = min(time + dt, end_time) else: - next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works + next_time = max(time - dt, end_time) self._kernel.execute(self, endtime=next_time, dt=dt) # TODO: Handle IO timing based of timedelta or datetime objects From cda41b321778a91271a8eb8f26618a31b4f4a5ed Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 29 Aug 2025 09:59:57 +0200 Subject: [PATCH 2/6] Fixing unit tests for _load_timesteps --- parcels/_index_search.py | 4 ++-- parcels/field.py | 25 +++++++++++++++---------- parcels/fieldset.py | 4 +++- parcels/particleset.py | 6 +++--- tests/v4/test_advection.py | 10 ++++++---- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 6d44217efe..d720ce7c8f 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -45,8 +45,8 @@ def _search_time_index(field: Field, time: datetime): if not field.time_interval.is_all_time_in_interval(time): _raise_time_extrapolation_error(time, field=None) - ti = np.searchsorted(field.data.time.data, time, side="right") - 1 - tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti]) + ti = np.zeros_like(time, dtype=np.int32) # TODO since ti is always zero, it can be removed? + tau = (time - field.data.time.data[0]) / (field.data.time.data[1] - field.data.time.data[0]) return np.atleast_1d(tau), np.atleast_1d(ti) diff --git a/parcels/field.py b/parcels/field.py index 84abc6b068..88676f1d98 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -221,16 +221,21 @@ def _check_velocitysampling(self): def _load_timesteps(self, time): """Load the appropriate timesteps of a field.""" - ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 - if not hasattr(self, "data"): - self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() - elif self.data_full.time.data[ti] == self.data.time.data[1]: - self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") - elif self.data_full.time.data[ti] != self.data.time.data[0]: - self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() - assert len(self.data.time) == 2, ( - f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." - ) + if hasattr(self.data_full, "time") and len(self.data_full.time) > 2: + ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 + if not hasattr(self, "data"): + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + elif self.data_full.time.data[ti] == self.data.time.data[1]: + self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") + elif self.data_full.time.data[ti] != self.data.time.data[0]: + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + assert len(self.data.time) == 2, ( + f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." + ) + return self.data_full.time.data[ti + 1] + else: + self.data = self.data_full + return None def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 9b32c95d46..9de11354f1 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -84,10 +84,12 @@ def time_interval(self): def _load_timesteps(self, time): """Load the appropriate timesteps of all fields in the fieldset.""" + next_time = 0 for fldname in self.fields: field = self.fields[fldname] if isinstance(field, Field): - field._load_timesteps(time) + next_time = min(next_time, field._load_timesteps(time)) + return next_time def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. diff --git a/parcels/particleset.py b/parcels/particleset.py index 5cbe240dca..5d9d2bab96 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -579,12 +579,12 @@ def execute( time = start_time while sign_dt * (time - end_time) < 0: # Load the appropriate timesteps of the fieldset - self.fieldset._load_timesteps(self._data["time_nextloop"][0]) + next_load_time = self.fieldset._load_timesteps(time) if sign_dt > 0: - next_time = min(time + dt, end_time) + next_time = min(next_load_time, end_time) else: - next_time = max(time - dt, end_time) + next_time = max(next_load_time, end_time) self._kernel.execute(self, endtime=next_time, dt=dt) # TODO: Handle IO timing based of timedelta or datetime objects diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 2522a8b6f5..90b07c38b3 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -88,9 +88,9 @@ def test_horizontal_advection_in_3D_flow(npart=10): """Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s.""" ds = simple_UV_dataset(mesh="flat") ds["U"].data[:] = 1.0 + ds["U"].data[:, 0, :, :] = 0.0 # Set U to 0 at the surface grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) - U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface V = Field("V", ds["V"], grid, interp_method=XLinear) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, UV]) @@ -106,12 +106,13 @@ def test_horizontal_advection_in_3D_flow(npart=10): @pytest.mark.parametrize("wErrorThroughSurface", [True, False]) def test_advection_3D_outofbounds(direction, wErrorThroughSurface): ds = simple_UV_dataset(mesh="flat") + ds["U"].data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds) + ds["W"] = ds["V"].copy() # Use V as W for testing + ds["W"].data[:] = -1.0 if direction == "up" else 1.0 grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) - U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds) V = Field("V", ds["V"], grid, interp_method=XLinear) - W = Field("W", ds["V"], grid, interp_method=XLinear) # Use V as W for testing - W.data[:] = -1.0 if direction == "up" else 1.0 + W = Field("W", ds["W"], grid, interp_method=XLinear) UVW = VectorField("UVW", U, V, W) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, W, UVW, UV]) @@ -191,6 +192,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read fields = [U, V, VectorField("UV", U, V)] if w: W = Field("W", ds["W"], grid, interp_method=XLinear) + fields.append(W) fields.append(VectorField("UVW", U, V, W)) fieldset = FieldSet(fields) From 7d1d001630ff9f39d34d771c1cca925bd1901ac4 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 29 Aug 2025 11:31:56 +0200 Subject: [PATCH 3/6] Fixing unit tests By only using _load_timesteps when is_dask_collection(data) --- parcels/_index_search.py | 5 +++-- parcels/field.py | 16 ++++++++++------ parcels/fieldset.py | 9 ++++++--- parcels/particleset.py | 4 ++-- tests/v4/test_uxarray_fieldset.py | 12 ++++++------ 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index d720ce7c8f..590496c927 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -45,8 +45,9 @@ def _search_time_index(field: Field, time: datetime): if not field.time_interval.is_all_time_in_interval(time): _raise_time_extrapolation_error(time, field=None) - ti = np.zeros_like(time, dtype=np.int32) # TODO since ti is always zero, it can be removed? - tau = (time - field.data.time.data[0]) / (field.data.time.data[1] - field.data.time.data[0]) + # TODO this could be sped up when data has only two timeslices (i.e. when data_full is not None)? + ti = np.searchsorted(field.data.time.data, time, side="right") - 1 + tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti]) return np.atleast_1d(tau), np.atleast_1d(ti) diff --git a/parcels/field.py b/parcels/field.py index 88676f1d98..f30c7914ae 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -8,6 +8,7 @@ import numpy as np import uxarray as ux import xarray as xr +from dask import is_dask_collection from parcels._core.utils.time import TimeInterval from parcels._reprs import default_repr @@ -127,8 +128,14 @@ def __init__( data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) self.name = name - self.data_full = data self.grid = grid + if is_dask_collection(data): + self.data = None + self.data_full = data + else: + self.data = data + self.data_full = None + self._nexttime_to_load = None try: self.time_interval = _get_time_interval(data) @@ -221,7 +228,7 @@ def _check_velocitysampling(self): def _load_timesteps(self, time): """Load the appropriate timesteps of a field.""" - if hasattr(self.data_full, "time") and len(self.data_full.time) > 2: + if self.data_full is not None: ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 if not hasattr(self, "data"): self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() @@ -232,10 +239,7 @@ def _load_timesteps(self, time): assert len(self.data.time) == 2, ( f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." ) - return self.data_full.time.data[ti + 1] - else: - self.data = self.data_full - return None + self._nexttime_to_load = self.data_full.time.data[ti + 1] def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 9de11354f1..6f666c3c60 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -84,12 +84,15 @@ def time_interval(self): def _load_timesteps(self, time): """Load the appropriate timesteps of all fields in the fieldset.""" - next_time = 0 + next_times = [] for fldname in self.fields: field = self.fields[fldname] if isinstance(field, Field): - next_time = min(next_time, field._load_timesteps(time)) - return next_time + field._load_timesteps(time) + if field._nexttime_to_load is not None: + next_times.append(field._nexttime_to_load) + + return min(next_times) if next_times else None def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. diff --git a/parcels/particleset.py b/parcels/particleset.py index 5d9d2bab96..ee903b5293 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -582,9 +582,9 @@ def execute( next_load_time = self.fieldset._load_timesteps(time) if sign_dt > 0: - next_time = min(next_load_time, end_time) + next_time = end_time if next_load_time is None else min(next_load_time, end_time) else: - next_time = max(next_load_time, end_time) + next_time = end_time if next_load_time is None else max(next_load_time, end_time) self._kernel.execute(self, endtime=next_time, dt=dt) # TODO: Handle IO timing based of timedelta or datetime objects diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index d3c8b5de62..53d91085cf 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -78,16 +78,16 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField: def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() pset = ParticleSet(fieldset, pclass=Particle) assert pset.fieldset == fieldset @@ -95,8 +95,8 @@ def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() # Set the interpolation method for each field fieldset.U.interp_method = UXPiecewiseConstantFace From edb0a4430a8191d019fec6b203cf5b898d31fe9c Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 29 Aug 2025 11:57:02 +0200 Subject: [PATCH 4/6] Fixing bug in field._load_timesteps() --- parcels/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/field.py b/parcels/field.py index f30c7914ae..f76331eea2 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -230,7 +230,7 @@ def _load_timesteps(self, time): """Load the appropriate timesteps of a field.""" if self.data_full is not None: ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 - if not hasattr(self, "data"): + if self.data is None: self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() elif self.data_full.time.data[ti] == self.data.time.data[1]: self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") From 1623ce200930bd141e2a718085ceb204777c4deb Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 29 Aug 2025 12:07:30 +0200 Subject: [PATCH 5/6] Fixing field xdim, ydim, zdim attributes --- parcels/field.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index f76331eea2..c4565347f6 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -185,27 +185,31 @@ def units(self, value): @property def xdim(self): - if type(self.data_full) is xr.DataArray: + if hasattr(self.grid, "xdim"): return self.grid.xdim else: raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): - if type(self.data_full) is xr.DataArray: + if hasattr(self.grid, "ydim"): return self.grid.ydim else: raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if type(self.data_full) is xr.DataArray: + if hasattr(self.grid, "zdim"): return self.grid.zdim else: if "nz1" in self.data_full.dims: return self.data_full.sizes["nz1"] + elif "nz1" in self.data.dims: + return self.data.sizes["nz1"] elif "nz" in self.data_full.dims: return self.data_full.sizes["nz"] + elif "nz" in self.data.dims: + return self.data.sizes["nz"] else: return 0 From 90e9d601b9c015346711ba29464b46c85ca383c6 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 12 Sep 2025 08:14:32 +0200 Subject: [PATCH 6/6] Fix bug when time is not in data.dims --- parcels/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/field.py b/parcels/field.py index b4b36f0c98..24ba915832 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -134,7 +134,7 @@ def __init__( self.name = name self.grid = grid - if is_dask_collection(data): + if is_dask_collection(data) and ("time" in data.dims): self.data = None self.data_full = data else: