docs/examples/example_nemo_curvilinear.py (5 additions, 0 deletions)
@@ -4,6 +4,7 @@
from glob import glob

import numpy as np
import pytest

import parcels

@@ -63,6 +64,10 @@ def test_nemo_curvilinear_AA(tmpdir):
run_nemo_curvilinear(outfile, "AA")


@pytest.mark.v4alpha
@pytest.mark.xfail(
reason="The method for checking whether fields are on the same grid is going to change in v4 (i.e., not by looking at the dataFiles attribute)."
)
def test_nemo_3D_samegrid():
"""Test that the same grid is used for U and V in 3D NEMO fields."""
data_folder = parcels.download_example_dataset("NemoNorthSeaORCA025-N006_data")
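A minimal, hypothetical sketch of the marker pattern introduced above (not part of the PR): combining a project-specific mark such as v4alpha with xfail records that the test is expected to fail until the v4 grid-equality check replaces the dataFiles-based comparison, and lets the suite be filtered with, for example, pytest -m "not v4alpha".

import pytest

@pytest.mark.v4alpha
@pytest.mark.xfail(reason="grid-equality check changes in v4")
def test_placeholder_same_grid():
    # Stand-in assertion for the v3-era dataFiles-based grid comparison.
    assert False
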
parcels/field.py (8 additions, 40 deletions)
@@ -1,7 +1,6 @@
import collections
import math
import warnings
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, cast

@@ -160,14 +159,12 @@ def __init__(
allow_time_extrapolation: bool | None = None,
gridindexingtype: GridIndexingType = "nemo",
to_write: bool = False,
**kwargs,
data_full_zdim=None,
):
if not isinstance(name, tuple):
self.name = name
self.filebuffername = name
else:
self.name = name[0]
self.filebuffername = name[1]
self.data = data
if grid:
self._grid = grid
@@ -187,7 +184,6 @@ def __init__(
self.units = unitconverters_map[self.fieldtype]
else:
raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'")
self._loaded_time_indices: Iterable[int] = [] # type: ignore
if isinstance(interp_method, dict):
if self.name in interp_method:
self.interp_method = interp_method[self.name]
@@ -214,31 +210,19 @@ def __init__(
self.allow_time_extrapolation = allow_time_extrapolation

self.data = self._reshape(self.data)
self._loaded_time_indices = range(self.grid.tdim)

# Hack around the fact that NaN and ridiculously large values
# propagate in SciPy's interpolators
self.data[np.isnan(self.data)] = 0.0

self._dimensions = kwargs.pop("dimensions", None)
self._dataFiles = kwargs.pop("dataFiles", None)
self._creation_log = kwargs.pop("creation_log", "")

# data_full_zdim is the vertical dimension of the complete field data, ignoring the indices.
# (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid,
# since some datasets do not provide the deeper level of data (which is ignored by the interpolation).
self.data_full_zdim = kwargs.pop("data_full_zdim", None)
self.filebuffers = [None] * 2
if len(kwargs) > 0:
raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"')
self.data_full_zdim = data_full_zdim

def __repr__(self) -> str:
return field_repr(self)

@property
def dimensions(self):
return self._dimensions

@property
def grid(self):
return self._grid
@@ -286,27 +270,19 @@ def _get_dim_filenames(cls, filenames, dim):
return filenames

@staticmethod
def _collect_timeslices(data_filenames, dimensions, indices):
timeslices = []
dataFiles = []
def _collect_time(data_filenames, dimensions, indices):
time = []
for fname in data_filenames:
with NetcdfFileBuffer(fname, dimensions, indices) as filebuffer:
ftime = filebuffer.time
timeslices.append(ftime)
dataFiles.append([fname] * len(ftime))
time = np.concatenate(timeslices).ravel()
dataFiles = np.concatenate(dataFiles).ravel()
time.append(ftime)
time = np.concatenate(time).ravel()
if time.size == 1 and time[0] is None:
time[0] = 0
time_origin = TimeConverter(time[0])
time = time_origin.reltime(time)

if not np.all((time[1:] - time[:-1]) > 0):
id_not_ordered = np.where(time[1:] < time[:-1])[0][0]
raise AssertionError(
f"Please make sure your netCDF files are ordered in time. First pair of non-ordered files: {dataFiles[id_not_ordered]}, {dataFiles[id_not_ordered + 1]}"
)
return time, time_origin, timeslices, dataFiles
return time, time_origin

@classmethod
def from_netcdf(
@@ -430,17 +406,11 @@ def from_netcdf(
# Concatenate time variable to determine overall dimension
# across multiple files
if "time" in dimensions:
time, time_origin, timeslices, dataFiles = cls._collect_timeslices(data_filenames, dimensions, indices)
time, time_origin = cls._collect_time(data_filenames, dimensions, indices)
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
kwargs["dataFiles"] = dataFiles
else: # e.g. for the CROCO CS_w field, see https://github.com/OceanParcels/Parcels/issues/1831
grid = Grid.create_grid(lon, lat, depth, np.array([0.0]), time_origin=TimeConverter(0.0), mesh=mesh)
data_filenames = [data_filenames[0]]
elif grid is not None and ("dataFiles" not in kwargs or kwargs["dataFiles"] is None):
# ==== means: the field has a shared grid, but may have different data files, so we need to collect the
# ==== correct file time series again.
_, _, _, dataFiles = cls._collect_timeslices(data_filenames, dimensions, indices)
kwargs["dataFiles"] = dataFiles

if "time" in indices:
warnings.warn(
@@ -473,8 +443,6 @@ def from_netcdf(
if allow_time_extrapolation is None:
allow_time_extrapolation = False if "time" in dimensions else True

kwargs["dimensions"] = dimensions.copy()

return cls(
variable,
data,
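The core of this change is the slimmed-down time collection: _collect_timeslices (which also tracked per-file dataFiles arrays) becomes _collect_time, which only concatenates the per-file time axes, converts them to times relative to the first timestamp, and verifies they increase monotonically. Below is a minimal standalone sketch of that logic, using plain NumPy in place of the Parcels NetcdfFileBuffer and TimeConverter helpers; it is an illustration, not the Parcels implementation.

import numpy as np

def collect_time(per_file_times):
    # per_file_times: one 1-D array of raw timestamps per netCDF file (assumed input).
    time = np.concatenate(per_file_times).ravel()
    time_origin = time[0]          # stand-in for TimeConverter(time[0])
    time = time - time_origin      # stand-in for time_origin.reltime(time)
    if not np.all(np.diff(time) > 0):
        raise AssertionError("Please make sure your netCDF files are ordered in time.")
    return time, time_origin

# Example: two files, each holding three daily timestamps (in seconds).
collect_time([np.arange(0, 3) * 86400.0, np.arange(3, 6) * 86400.0])

Note that without the file-to-timestep mapping (dataFiles), the ordering error presumably can no longer name the offending pair of files.
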
parcels/fieldset.py (1 addition, 48 deletions)
@@ -136,8 +136,6 @@ def from_data(
else:
time_origin = kwargs.pop("time_origin", TimeConverter(0))
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_data"

fields[name] = Field(
name,
@@ -180,7 +178,6 @@ def add_field(self, field: Field, name: str | None = None):
else:
setattr(self, name, field)
self.gridset.add_grid(field)
field.fieldset = self

def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
"""Wrapper function to add a Field that is constant in space,
@@ -214,7 +211,6 @@ def add_vector_field(self, vfield):
for v in vfield.__dict__.values():
if isinstance(v, Field) and (v not in self.get_fields()):
self.add_field(v)
vfield.fieldset = self

def _add_UVfield(self):
if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
@@ -266,9 +262,6 @@ def check_velocityfields(U, V, W):
g._time_origin = self.time_origin
self._add_UVfield()

for f in self.get_fields():
if isinstance(f, VectorField) or f._dataFiles is None:
continue
self._completed = True

@classmethod
@@ -351,8 +344,6 @@ def from_netcdf(

"""
fields: dict[str, Field] = {}
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_netcdf"
for var, name in variables.items():
# Resolve all matching paths for the current variable
paths = filenames[var] if type(filenames) is dict and var in filenames else filenames
@@ -368,28 +359,7 @@ def from_netcdf(
fieldtype = fieldtype[var] if (fieldtype and var in fieldtype) else fieldtype

grid = None
dFiles = None
# check if grid has already been processed (i.e. if other fields have same filenames, dimensions and indices)
for procvar, _ in fields.items():
procdims = dimensions[procvar] if procvar in dimensions else dimensions
nowpaths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames
if procdims == dims:
possibly_samegrid = True
if not possibly_samegrid:
break
processedGrid = False
if (not isinstance(filenames, dict)) or filenames[procvar] == filenames[var]:
processedGrid = True
elif isinstance(filenames[procvar], dict):
processedGrid = True
for dim in ["lon", "lat", "depth"]:
if dim in dimensions:
processedGrid *= filenames[procvar][dim] == filenames[var][dim]
if processedGrid:
grid = fields[procvar].grid
if filenames == nowpaths:
dFiles = fields[procvar]._dataFiles
break

fields[var] = Field.from_netcdf(
paths,
(var, name),
@@ -398,7 +368,6 @@ def from_netcdf(
mesh=mesh,
allow_time_extrapolation=allow_time_extrapolation,
fieldtype=fieldtype,
dataFiles=dFiles,
**kwargs,
)

@@ -481,8 +450,6 @@ def from_nemo(
Keyword arguments passed to the :func:`Fieldset.from_c_grid_dataset` constructor.

"""
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_nemo"
if kwargs.pop("gridindexingtype", "nemo") != "nemo":
raise ValueError(
"gridindexingtype must be 'nemo' in FieldSet.from_nemo(). Use FieldSet.from_c_grid_dataset otherwise"
@@ -525,8 +492,6 @@ def from_mitgcm(
For indexing details: https://mitgcm.readthedocs.io/en/latest/algorithm/algorithm.html#spatial-discretization-of-the-dynamical-equations
Note that vertical velocity (W) is assumed positive in the positive z direction (which is upward in MITgcm)
"""
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_mitgcm"
if kwargs.pop("gridindexingtype", "mitgcm") != "mitgcm":
raise ValueError(
"gridindexingtype must be 'mitgcm' in FieldSet.from_mitgcm(). Use FieldSet.from_c_grid_dataset otherwise"
@@ -568,8 +533,6 @@ def from_croco(

See `the CROCO 3D tutorial <../examples/tutorial_croco_3D.ipynb>`__ for more information.
"""
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_croco"
if kwargs.pop("gridindexingtype", "croco") != "croco":
raise ValueError(
"gridindexingtype must be 'croco' in FieldSet.from_croco(). Use FieldSet.from_c_grid_dataset otherwise"
@@ -717,8 +680,6 @@ def from_c_grid_dataset(
interp_method[v] = "cgrid_velocity"
else:
interp_method[v] = tracer_interp_method
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_c_grid_dataset"

return cls.from_netcdf(
filenames,
@@ -800,8 +761,6 @@ def from_mom5(
**kwargs :
Keyword arguments passed to the :func:`Fieldset.from_b_grid_dataset` constructor.
"""
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_mom5"
fieldset = cls.from_b_grid_dataset(
filenames,
variables,
@@ -924,8 +883,6 @@ def from_b_grid_dataset(
interp_method[v] = "bgrid_w_velocity"
else:
interp_method[v] = tracer_interp_method
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_b_grid_dataset"

return cls.from_netcdf(
filenames,
@@ -972,8 +929,6 @@ def from_parcels(
"""
if extra_fields is None:
extra_fields = {}
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_parcels"

dimensions = {}
default_dims = {"lon": "nav_lon", "lat": "nav_lat", "depth": "depth", "time": "time_counter"}
@@ -1025,8 +980,6 @@ def from_xarray_dataset(cls, ds, variables, dimensions, mesh="spherical", allow_
Keyword arguments passed to the :func:`Field.from_xarray` constructor.
"""
fields = {}
if "creation_log" not in kwargs.keys():
kwargs["creation_log"] = "from_xarray_dataset"

for var, name in variables.items():
dims = dimensions[var] if var in dimensions else dimensions
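A brief, hypothetical usage sketch (assuming a Parcels v4-alpha environment) of what remains after the removals above: the from_* constructors no longer stash a creation_log kwarg, and FieldSet.from_data is called exactly as before. With **kwargs gone from Field.__init__, an unknown keyword argument would presumably surface as a plain TypeError from Python rather than the custom SyntaxError that the deleted test in tests/test_fieldset.py below used to check.

import numpy as np
from parcels import FieldSet

nx, ny = 10, 10
data = {"U": np.zeros((ny, nx)), "V": np.zeros((ny, nx))}
dimensions = {"lon": np.linspace(0.0, 1.0, nx), "lat": np.linspace(0.0, 1.0, ny)}

fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
# fieldset.U._creation_log is no longer set, and the dataFiles-based
# grid-sharing lookup in from_netcdf has been removed as well.
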
tests/test_fieldset.py (0 additions, 13 deletions)
@@ -86,21 +86,12 @@ def test_fieldset_from_data(xdim, ydim):
"""Simple test for fieldset initialisation from data."""
data, dimensions = generate_fieldset_data(xdim, ydim)
fieldset = FieldSet.from_data(data, dimensions)
assert fieldset.U._creation_log == "from_data"
assert len(fieldset.U.data.shape) == 3
assert len(fieldset.V.data.shape) == 3
assert np.allclose(fieldset.U.data[0, :], data["U"], rtol=1e-12)
assert np.allclose(fieldset.V.data[0, :], data["V"], rtol=1e-12)


def test_fieldset_extra_syntax():
"""Simple test for fieldset initialisation from data."""
data, dimensions = generate_fieldset_data(10, 10)

with pytest.raises(SyntaxError):
FieldSet.from_data(data, dimensions, unknown_keyword=5)


@pytest.mark.v4remove
@pytest.mark.xfail(reason="vmin and vmax were removed as arguments")
def test_fieldset_vmin_vmax():
@@ -172,7 +163,6 @@ def test_fieldset_from_modulefile():
nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py")

fieldset = FieldSet.from_modulefile(nemo_fname)
assert fieldset.U._creation_log == "from_nemo"

fieldset = FieldSet.from_modulefile(nemo_fname)
assert fieldset.U.grid.lon.shape[1] == 21
@@ -379,7 +369,6 @@ def test_fieldset_write_curvilinear(tmpdir):
variables = {"dx": "e1u"}
dimensions = {"lon": "glamu", "lat": "gphiu"}
fieldset = FieldSet.from_nemo(filenames, variables, dimensions)
assert fieldset.dx._creation_log == "from_nemo"

newfile = tmpdir.join("curv_field")
fieldset.write(newfile)
@@ -389,7 +378,6 @@ def test_fieldset_write_curvilinear(tmpdir):
variables={"dx": "dx"},
dimensions={"time": "time_counter", "depth": "depthdx", "lon": "nav_lon", "lat": "nav_lat"},
)
assert fieldset2.dx._creation_log == "from_netcdf"

for var in ["lon", "lat", "data"]:
assert np.allclose(getattr(fieldset2.dx, var), getattr(fieldset.dx, var))
@@ -648,7 +636,6 @@ def generate_dataset(xdim, ydim, zdim=1, tdim=1):
else:
dimensions = {"lat": "lat", "lon": "lon", "depth": "depth"}
fieldset = FieldSet.from_xarray_dataset(ds, variables, dimensions, mesh="flat")
assert fieldset.U._creation_log == "from_xarray_dataset"

pset = ParticleSet(fieldset, Particle, 0, 0, depth=20)
