Skip to content

Commit

Permalink
Merge pull request #1810 from OceanParcels/warn_particle_times_outsid…
Browse files Browse the repository at this point in the history
…e_fieldset_time_bounds

Implement warning for particles initialised outside time domain
  • Loading branch information
erikvansebille authored Jan 7, 2025
2 parents 70e26eb + baa5f91 commit f46b2be
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 12 deletions.
26 changes: 22 additions & 4 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from parcels._compat import MPI
from parcels.application_kernels.advection import AdvectionRK4
from parcels.compilation.codecompiler import GNUCompiler
from parcels.field import NestedField
from parcels.field import Field, NestedField
from parcels.grid import CurvilinearGrid, GridType
from parcels.interaction.interactionkernel import InteractionKernel
from parcels.interaction.neighborsearch import (
Expand All @@ -32,7 +32,7 @@
from parcels.tools.global_statics import get_package_dir
from parcels.tools.loggers import logger
from parcels.tools.statuscodes import StatusCode
from parcels.tools.warnings import FileWarning
from parcels.tools.warnings import ParticleSetWarning

__all__ = ["ParticleSet"]

Expand Down Expand Up @@ -174,6 +174,8 @@ def ArrayClass_init(self, *args, **kwargs):
raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double")
time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time])
assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation):
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full)

if lonlatdepth_dtype is None:
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U)
Expand Down Expand Up @@ -792,7 +794,7 @@ def from_particlefile(
f"Note that the `repeatdt` argument is not retained from {filename}, and that "
"setting a new repeatdt will start particles from the _new_ particle "
"locations.",
FileWarning,
ParticleSetWarning,
stacklevel=2,
)

Expand Down Expand Up @@ -1247,6 +1249,22 @@ def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_tim
"Some of the particles have a start time difference that is not a multiple of outputdt. "
"This could cause the first output of some of the particles that start later "
"in the simulation to be at a different time than expected.",
FileWarning,
ParticleSetWarning,
stacklevel=2,
)


def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time_full: np.ndarray):
if np.any(release_times):
if np.any(release_times < time_full[0]):
warnings.warn(
"Some particles are set to be released before the fieldset's first time and allow_time_extrapolation is set to False.",
ParticleSetWarning,
stacklevel=2,
)
if np.any(release_times > time_full[-1]):
warnings.warn(
"Some particles are set to be released after the fieldset's last time and allow_time_extrapolation is set to False.",
ParticleSetWarning,
stacklevel=2,
)
8 changes: 7 additions & 1 deletion parcels/tools/warnings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

__all__ = ["FieldSetWarning", "FileWarning", "KernelWarning"]
__all__ = ["FieldSetWarning", "FileWarning", "KernelWarning", "ParticleSetWarning"]


class FieldSetWarning(UserWarning):
Expand All @@ -13,6 +13,12 @@ class FieldSetWarning(UserWarning):
pass


class ParticleSetWarning(UserWarning):
"""Warning that is raised when there are issues in the construction of the ParticleSet."""

pass


class FileWarning(UserWarning):
"""Warning that is raised when there are issues with input or output files.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_particlesets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
FieldSet,
JITParticle,
ParticleSet,
ParticleSetWarning,
ScipyParticle,
StatusCode,
Variable,
Expand Down Expand Up @@ -175,6 +176,14 @@ def test_pset_create_with_time(fieldset, mode):
assert np.allclose([p.time for p in pset], time, rtol=1e-12)


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_pset_create_outside_time(mode):
fieldset = create_fieldset_zeros_simple(withtime=True)
time = [-1, 0, 1, 20 * 86400]
with pytest.warns(ParticleSetWarning, match="Some particles are set to be released*"):
ParticleSet(fieldset, pclass=ptype[mode], lon=[0] * len(time), lat=[0] * len(time), time=time)


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_pset_not_multipldt_time(fieldset, mode):
times = [0, 1.1]
Expand Down
4 changes: 2 additions & 2 deletions tests/tools/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
AdvectionRK45,
FieldSet,
FieldSetWarning,
FileWarning,
KernelWarning,
ParticleSet,
ParticleSetWarning,
ScipyParticle,
)
from tests.utils import TEST_DATA
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_file_warnings(tmp_zarrfile):
)
pset = ParticleSet(fieldset=fieldset, pclass=ScipyParticle, lon=[0, 0], lat=[0, 0], time=[0, 1])
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2)
with pytest.warns(FileWarning, match="Some of the particles have a start time difference.*"):
with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"):
pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile)


Expand Down
19 changes: 14 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,24 @@ def create_flat_positions(n_particle):
return np.random.rand(n_particle * 3).reshape(3, n_particle)


def create_fieldset_zeros_simple(xdim=40, ydim=100):
U = np.zeros((ydim, xdim), dtype=np.float32)
V = np.zeros((ydim, xdim), dtype=np.float32)
def create_fieldset_zeros_simple(xdim=40, ydim=100, withtime=False):
lon = np.linspace(0, 1, xdim, dtype=np.float32)
lat = np.linspace(-60, 60, ydim, dtype=np.float32)
depth = np.zeros(1, dtype=np.float32)
data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32)}
dimensions = {"lat": lat, "lon": lon, "depth": depth}
return FieldSet.from_data(data, dimensions)
if withtime:
tdim = 10
time = np.linspace(0, 86400, tdim, dtype=np.float64)
dimensions["time"] = time
datadims = (tdim, ydim, xdim)
allow_time_extrapolation = False
else:
datadims = (ydim, xdim)
allow_time_extrapolation = True
U = np.zeros(datadims, dtype=np.float32)
V = np.zeros(datadims, dtype=np.float32)
data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32)}
return FieldSet.from_data(data, dimensions, allow_time_extrapolation=allow_time_extrapolation)


def assert_empty_folder(path: Path):
Expand Down

0 comments on commit f46b2be

Please sign in to comment.