Skip to content

Commit

Permalink
Merge pull request #1808 from OceanParcels/fix_particlefile_npinf_bug
Browse files Browse the repository at this point in the history
Fixing a bug with default ParticleFile outputdt value
  • Loading branch information
erikvansebille authored Jan 6, 2025
2 parents 0742868 + 35db303 commit b73b638
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ParticleFile:
ParticleFile object that can be used to write particle data to file
"""

def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
def __init__(self, name, particleset, outputdt, chunks=None, create_new_zarrfile=True):
self._outputdt = timedelta_to_float(outputdt)
self._chunks = chunks
self._particleset = particleset
Expand Down Expand Up @@ -360,7 +360,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
if len(once_ids) > 0:
Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
else:
if max(obs) >= Z[varout].shape[1]:
if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var]
self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1)
Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def fieldset():
def test_metadata(fieldset, mode, tmp_zarrfile):
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0, lat=0)

pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile))
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=1))

ds = xr.open_zarr(tmp_zarrfile)
assert ds.attrs["parcels_kernels"].lower() == f"{mode}ParticleDoNothing".lower()
Expand All @@ -47,7 +47,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
npart = 10
zarr_store = MemoryStore()
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
pfile = pset.ParticleFile(zarr_store)
pfile = pset.ParticleFile(zarr_store, outputdt=1)
pfile.write(pset, 0)

ds = xr.open_zarr(zarr_store)
Expand All @@ -59,7 +59,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
def test_pfile_array_remove_particles(fieldset, mode, tmp_zarrfile):
npart = 10
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
pfile = pset.ParticleFile(tmp_zarrfile)
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
pfile.write(pset, 0)
pset.remove_indices(3)
for p in pset:
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarrfi
npart = 10
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
chunks = (npart, chunks_obs) if chunks_obs else None
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks)
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1)
pfile.write(pset, 0)
for _ in range(npart):
pset.remove_indices(-1)
Expand Down

0 comments on commit b73b638

Please sign in to comment.