diff --git a/tests/instruments/test_drifter.py b/tests/instruments/test_drifter.py new file mode 100644 index 00000000..e239b2a6 --- /dev/null +++ b/tests/instruments/test_drifter.py @@ -0,0 +1,86 @@ +"""Test the simulation of drifters.""" + +import datetime +from datetime import timedelta + +import numpy as np +import py +import xarray as xr +from parcels import FieldSet + +from virtual_ship import Location, Spacetime +from virtual_ship.instruments.drifter import Drifter, simulate_drifters + + +def test_simulate_drifters(tmpdir: py.path.LocalPath) -> None: + # arbitrary time offset for the dummy fieldset + base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") + + CONST_TEMPERATURE = 1.0 # constant temperature in fieldset + + v = np.full((2, 2, 2), 1.0) + u = np.full((2, 2, 2), 1.0) + t = np.full((2, 2, 2), CONST_TEMPERATURE) + + fieldset = FieldSet.from_data( + {"V": v, "U": u, "T": t}, + { + "lon": np.array([0.0, 10.0]), + "lat": np.array([0.0, 10.0]), + "time": [ + np.datetime64(base_time + datetime.timedelta(seconds=0)), + np.datetime64(base_time + datetime.timedelta(hours=4)), + ], + }, + ) + + # drifters to deploy + drifters = [ + Drifter( + spacetime=Spacetime( + location=Location(latitude=0, longitude=0), + time=base_time + datetime.timedelta(days=0), + ), + depth=0.0, + lifetime=datetime.timedelta(hours=2), + ), + Drifter( + spacetime=Spacetime( + location=Location(latitude=1, longitude=1), + time=base_time + datetime.timedelta(hours=1), + ), + depth=0.0, + lifetime=None, + ), + ] + + # perform simulation + out_path = tmpdir.join("out.zarr") + + simulate_drifters( + fieldset=fieldset, + out_path=out_path, + drifters=drifters, + outputdt=timedelta(hours=1), + dt=timedelta(minutes=5), + endtime=None, + ) + + # test if output is as expected + results = xr.open_zarr(out_path) + + assert len(results.trajectory) == len(drifters) + + for drifter_i, traj in enumerate(results.trajectory): + # Check if drifters are moving + # lat, lon, should be increasing values (with the above positive VU fieldset) + assert np.all( + np.diff(results.sel(trajectory=traj)["lat"].values) > 0 + ), f"Drifter is not moving over y {drifter_i=}" + assert np.all( + np.diff(results.sel(trajectory=traj)["lon"].values) > 0 + ), f"Drifter is not mvoing over x {drifter_i=}" + + assert np.all( + results.sel(trajectory=traj)["temperature"] == CONST_TEMPERATURE + ), f"measured temperature does not match {drifter_i=}" diff --git a/tests/instruments/test_drifters.py b/tests/instruments/test_drifters.py deleted file mode 100644 index 421499d6..00000000 --- a/tests/instruments/test_drifters.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Test the simulation of drifters.""" - -from datetime import timedelta - -import numpy as np -from parcels import FieldSet - -from virtual_ship import Location, Spacetime -from virtual_ship.instruments.drifter import Drifter, simulate_drifters - - -def test_simulate_drifters() -> None: - fieldset = FieldSet.from_data( - {"U": 0, "V": 0, "T": 0}, - { - "lon": 0, - "lat": 0, - "time": [np.datetime64("1950-01-01") + np.timedelta64(632160, "h")], - }, - ) - - min_depth = -fieldset.U.depth[0] - - drifters = [ - Drifter( - spacetime=Spacetime(location=Location(latitude=0, longitude=0), time=0), - min_depth=min_depth, - ) - ] - - simulate_drifters( - drifters=drifters, - fieldset=fieldset, - out_file_name="test", - outputdt=timedelta(minutes=5), - ) diff --git a/tests/test_sailship.py b/tests/test_sailship.py index 5d37f91c..6cbe6b0d 100644 --- a/tests/test_sailship.py +++ b/tests/test_sailship.py @@ -9,10 +9,7 @@ from virtual_ship.virtual_ship_configuration import VirtualShipConfiguration -def _make_ctd_fieldset() -> FieldSet: - # arbitrary time offset for the dummy fieldset - base_time = datetime.datetime.strptime("2022-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S") - +def _make_ctd_fieldset(base_time: datetime) -> FieldSet: u = np.zeros((2, 2, 2, 2)) v = np.zeros((2, 2, 2, 2)) t = np.zeros((2, 2, 2, 2)) @@ -34,7 +31,29 @@ def _make_ctd_fieldset() -> FieldSet: return fieldset +def _make_drifter_fieldset(base_time: datetime) -> FieldSet: + v = np.full((2, 2, 2), 1.0) + u = np.full((2, 2, 2), 1.0) + t = np.full((2, 2, 2), 1.0) + + fieldset = FieldSet.from_data( + {"V": v, "U": u, "T": t}, + { + "time": [ + np.datetime64(base_time + datetime.timedelta(seconds=0)), + np.datetime64(base_time + datetime.timedelta(weeks=10)), + ], + "lat": [-40, 90], + "lon": [-90, 90], + }, + ) + return fieldset + + def test_sailship() -> None: + # arbitrary time offset for the dummy fieldsets + base_time = datetime.datetime.strptime("2022-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S") + adcp_fieldset = FieldSet.from_data( {"U": 0, "V": 0}, {"lon": 0, "lat": 0}, @@ -45,16 +64,9 @@ def test_sailship() -> None: {"lon": 0, "lat": 0}, ) - ctd_fieldset = _make_ctd_fieldset() + ctd_fieldset = _make_ctd_fieldset(base_time) - drifter_fieldset = FieldSet.from_data( - {"U": 0, "V": 0, "T": 0}, - { - "lon": 0, - "lat": 0, - "time": [np.datetime64("1950-01-01") + np.timedelta64(632160, "h")], - }, - ) + drifter_fieldset = _make_drifter_fieldset(base_time) argo_float_fieldset = FieldSet.from_data( {"U": 0, "V": 0, "T": 0, "S": 0}, diff --git a/virtual_ship/instruments/drifter.py b/virtual_ship/instruments/drifter.py index 5bed859f..3a5be7ab 100644 --- a/virtual_ship/instruments/drifter.py +++ b/virtual_ship/instruments/drifter.py @@ -1,9 +1,10 @@ """Drifter instrument.""" from dataclasses import dataclass -from datetime import timedelta +from datetime import datetime, timedelta import numpy as np +import py from parcels import AdvectionRK4, FieldSet, JITParticle, ParticleSet, Variable from ..spacetime import Spacetime @@ -14,12 +15,16 @@ class Drifter: """Configuration for a single Drifter.""" spacetime: Spacetime - min_depth: float + depth: float # depth at which it floats and samples + lifetime: timedelta | None # if none, lifetime is infinite _DrifterParticle = JITParticle.add_variables( [ Variable("temperature", dtype=np.float32, initial=np.nan), + Variable("has_lifetime", dtype=np.int8), # bool + Variable("age", dtype=np.float32, initial=0.0), + Variable("lifetime", dtype=np.float32), ] ) @@ -28,53 +33,72 @@ def _sample_temperature(particle, fieldset, time): particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] -def _check_error(particle, fieldset, time): - if particle.state >= 50: # This captures all Errors - particle.delete() +def _check_lifetime(particle, fieldset, time): + if particle.has_lifetime == 1: + particle.age += particle.dt + if particle.age >= particle.lifetime: + particle.delete() def simulate_drifters( - drifters: list[Drifter], fieldset: FieldSet, - out_file_name: str, + out_path: str | py.path.LocalPath, + drifters: list[Drifter], outputdt: timedelta, + dt: timedelta, + endtime: datetime | None, ) -> None: """ Use parcels to simulate a set of drifters in a fieldset. + :param fieldset: The fieldset to simulate the Drifters in. + :param out_path: The path to write the results to. :param drifters: A list of drifters to simulate. - :param fieldset: The fieldset to simulate the drifters in. - :param out_file_name: The file to write the results to. - :param outputdt: Interval which dictates the update frequency of file output during simulation + :param outputdt: Interval which dictates the update frequency of file output during simulation. + :param dt: Dt for integration. + :param endtime: Stop at this time, or if None, continue until the end of the fieldset or until all drifters ended. If this is earlier than the last drifter ended or later than the end of the fieldset, a warning will be printed. """ - lon = [drifter.spacetime.location.lon for drifter in drifters] - lat = [drifter.spacetime.location.lat for drifter in drifters] - time = [drifter.spacetime.time for drifter in drifters] - # define parcel particles drifter_particleset = ParticleSet( fieldset=fieldset, pclass=_DrifterParticle, - lon=lon, - lat=lat, - depth=[drifter.min_depth for drifter in drifters], - time=time, + lat=[drifter.spacetime.location.lat for drifter in drifters], + lon=[drifter.spacetime.location.lon for drifter in drifters], + depth=[drifter.depth for drifter in drifters], + time=[drifter.spacetime.time for drifter in drifters], + has_lifetime=[1 if drifter.lifetime is not None else 0 for drifter in drifters], + lifetime=[ + 0 if drifter.lifetime is None else drifter.lifetime.total_seconds() + for drifter in drifters + ], ) # define output file for the simulation - out_file = drifter_particleset.ParticleFile( - name=out_file_name, - outputdt=outputdt, - chunks=(1, 500), - ) + out_file = drifter_particleset.ParticleFile(name=out_path, outputdt=outputdt) - # get time when the fieldset ends + # get earliest between fieldset end time and provide end time fieldset_endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + if endtime is None: + actual_endtime = fieldset_endtime + elif endtime > fieldset_endtime: + print("WARN: Requested end time later than fieldset end time.") + actual_endtime = fieldset_endtime + else: + actual_endtime = np.timedelta64(endtime) # execute simulation drifter_particleset.execute( - [AdvectionRK4, _sample_temperature, _check_error], - endtime=fieldset_endtime, - dt=outputdt, + [AdvectionRK4, _sample_temperature, _check_lifetime], + endtime=actual_endtime, + dt=dt, output_file=out_file, + verbose_progress=False, ) + + # if there are more particles left than the number of drifters with an indefinite endtime, warn the user + if len(drifter_particleset.particledata) > len( + [d for d in drifters if d.lifetime is None] + ): + print( + "WARN: Some drifters had a life time beyond the end time of the fieldset or the requested end time." + ) diff --git a/virtual_ship/sailship.py b/virtual_ship/sailship.py index 8a5670c5..d9dc4a85 100644 --- a/virtual_ship/sailship.py +++ b/virtual_ship/sailship.py @@ -114,7 +114,8 @@ def sailship(config: VirtualShipConfiguration): spacetime=Spacetime( location=route_point, time=time_past.total_seconds() ), - min_depth=-config.drifter_fieldset.U.depth[0], + depth=-config.drifter_fieldset.U.depth[0], + lifetime=timedelta(weeks=4), ) ) drifter_locations_visited = drifter_locations_visited.union(drifters_here) @@ -228,10 +229,12 @@ def sailship(config: VirtualShipConfiguration): print("Simulating drifters") simulate_drifters( - drifters=drifters, + out_path=os.path.join("results", "drifters.zarr"), fieldset=config.drifter_fieldset, - out_file_name=os.path.join("results", "drifters.zarr"), - outputdt=timedelta(minutes=5), + drifters=drifters, + outputdt=timedelta(hours=5), + dt=timedelta(minutes=5), + endtime=None, ) print("Simulating argo floats") diff --git a/virtual_ship/spacetime.py b/virtual_ship/spacetime.py index 03804975..b2798d7e 100644 --- a/virtual_ship/spacetime.py +++ b/virtual_ship/spacetime.py @@ -7,7 +7,6 @@ @dataclass -# TODO I take suggestions for a better name class Spacetime: """A location and time."""