Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single timestep dataset forgets dt #692

Merged
merged 3 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions mikeio/dataset/_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ def __call__(self, tail: bool = True) -> "DataArray":
geometry = GeometryUndefined()

return DataArray(
data=Hm0, time=self.da.time, item=item, dims=dims, geometry=geometry
data=Hm0,
time=self.da.time,
item=item,
dims=dims,
geometry=geometry,
dt=self.da._dt,
)


Expand Down Expand Up @@ -162,10 +167,12 @@ def __init__(
geometry: GeometryType | None = None,
zn: np.ndarray | None = None,
dims: Sequence[str] | None = None,
dt: float = 1.0,
) -> None:
# TODO: add optional validation validate=True
self._values = self._parse_data(data)
self.time: pd.DatetimeIndex = self._parse_time(time)
self._dt = dt

geometry = GeometryUndefined() if geometry is None else geometry
self.dims = self._parse_dims(dims, geometry)
Expand Down Expand Up @@ -421,11 +428,11 @@ def is_equidistant(self) -> bool:
return len(self.time.to_series().diff().dropna().unique()) == 1

@property
def timestep(self) -> float | None:
def timestep(self) -> float:
"""Time step in seconds if equidistant (and at
least two time instances); otherwise None
least two time instances); otherwise original time step is returned.
"""
dt = None
dt = self._dt
if len(self.time) > 1 and self.is_equidistant:
first: pd.Timestamp = self.time[0]
second: pd.Timestamp = self.time[1]
Expand Down Expand Up @@ -539,6 +546,7 @@ def squeeze(self) -> "DataArray":
geometry=self.geometry,
zn=self._zn,
dims=tuple(dims),
dt=self._dt,
)

# ============= Select/interp ===========
Expand Down Expand Up @@ -718,6 +726,7 @@ def isel(
geometry=geometry,
zn=zn,
dims=dims,
dt=self._dt,
)

def sel(
Expand Down Expand Up @@ -969,7 +978,11 @@ def interp(
# )

da = DataArray(
data=dai, time=self.time, geometry=geometry, item=deepcopy(self.item)
data=dai,
time=self.time,
geometry=geometry,
item=deepcopy(self.item),
dt=self._dt,
)
else:
da = self.copy()
Expand Down Expand Up @@ -1097,6 +1110,7 @@ def interp_time(
item=deepcopy(self.item),
geometry=self.geometry,
zn=zn,
dt=self._dt,
)

def interp_na(self, axis: str = "time", **kwargs: Any) -> "DataArray":
Expand Down Expand Up @@ -1197,7 +1211,11 @@ def interp_like(
)
assert isinstance(ari, np.ndarray)
dai = DataArray(
data=ari, time=self.time, geometry=geom, item=deepcopy(self.item)
data=ari,
time=self.time,
geometry=geom,
item=deepcopy(self.item),
dt=self._dt,
)

if hasattr(other, "time"):
Expand Down Expand Up @@ -1506,6 +1524,7 @@ def aggregate(
geometry=geometry,
dims=dims,
zn=zn,
dt=self._dt,
)

@overload
Expand Down Expand Up @@ -1599,7 +1618,13 @@ def _quantile(self, q, *, axis: int | str = 0, func=np.quantile, **kwargs: Any):
dims = tuple([d for i, d in enumerate(self.dims) if i != axis])
item = deepcopy(self.item)
return DataArray(
data=qdat, time=time, item=item, geometry=geometry, dims=dims, zn=zn
data=qdat,
time=time,
item=item,
geometry=geometry,
dims=dims,
zn=zn,
dt=self._dt,
)
else:
res = []
Expand Down Expand Up @@ -1747,6 +1772,7 @@ def _boolmask_to_new_DataArray(self, bmask) -> "DataArray": # type: ignore
item=ItemInfo("Boolean"),
geometry=self.geometry,
zn=self._zn,
dt=self._dt,
)

# ============= output methods: to_xxx() ===========
Expand Down
27 changes: 17 additions & 10 deletions mikeio/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ def __init__(
zn: NDArray[np.floating] | None = None,
dims: Tuple[str, ...] | None = None,
validate: bool = True,
dt: float = 1.0,
):
if not self._is_DataArrays(data):
data = self._create_dataarrays(
data=data, time=time, items=items, geometry=geometry, zn=zn, dims=dims # type: ignore
data=data, time=time, items=items, geometry=geometry, zn=zn, dims=dims, dt=dt # type: ignore
)
self._data_vars: MutableMapping[str, DataArray] = self._init_from_DataArrays(data, validate=validate) # type: ignore
self.plot = _DatasetPlotter(self)
Expand All @@ -123,11 +124,12 @@ def _is_DataArrays(data: Any) -> bool:
@staticmethod
def _create_dataarrays(
data: Sequence[NDArray[np.floating]] | NDArray[np.floating],
time: pd.DatetimeIndex | None = None,
items: Sequence[ItemInfo] | None = None,
geometry: Any = None,
zn: NDArray[np.floating] | None = None,
dims: Tuple[str, ...] | None = None,
time: pd.DatetimeIndex,
items: Sequence[ItemInfo],
geometry: Any,
zn: NDArray[np.floating],
dims: Tuple[str, ...],
dt: float,
) -> Mapping[str, DataArray]:
if not isinstance(data, Iterable):
data = [data]
Expand All @@ -137,7 +139,7 @@ def _create_dataarrays(
data_vars = {}
for dd, it in zip(data, items):
data_vars[it.name] = DataArray(
data=dd, time=time, item=it, geometry=geometry, zn=zn, dims=dims
data=dd, time=time, item=it, geometry=geometry, zn=zn, dims=dims, dt=dt
)
return data_vars

Expand Down Expand Up @@ -303,6 +305,11 @@ def _check_already_present(self, new_da: DataArray) -> None:

# ============= Basic properties/methods ===========

@property
def _dt(self) -> float:
    """Original time step in seconds.

    The value is not stored separately here; it is looked up on the
    first item (``self[0]``) and that item's ``_dt`` is returned.
    """
    first_item = self[0]
    return first_item._dt

@property
def time(self) -> pd.DatetimeIndex:
"""Time axis"""
Expand All @@ -326,11 +333,11 @@ def end_time(self) -> datetime:
return self.time[-1].to_pydatetime() # type: ignore

@property
def timestep(self) -> float | None:
def timestep(self) -> float:
"""Time step in seconds if equidistant (and at
least two time instances); otherwise None
least two time instances); otherwise original time step is returned.
"""
dt = None
dt = self._dt
if len(self.time) > 1 and self.is_equidistant:
dt = (self.time[1] - self.time[0]).total_seconds()
return dt
Expand Down
9 changes: 8 additions & 1 deletion mikeio/dfs/_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,14 @@ def read(
items = _get_item_info(self._dfs.ItemInfo, item_numbers)

self._dfs.Close()
return Dataset(data_list, time, items, geometry=self.geometry, validate=False)
return Dataset(
data_list,
time,
items,
geometry=self.geometry,
validate=False,
dt=self._timestep,
)

def _open(self) -> None:
raise NotImplementedError("Should be implemented by subclass")
Expand Down
17 changes: 10 additions & 7 deletions mikeio/dfsu/_dfsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@ def write_dfsu(filename: str | Path, data: Dataset) -> None:
"""
filename = str(filename)

if len(data.time) == 1:
dt = 1 # TODO is there any sensible default?
else:
if not data.is_equidistant:
raise ValueError("Non-equidistant time axis is not supported.")
if not data.is_equidistant:
raise ValueError("Non-equidistant time axis is not supported.")

dt = (data.time[1] - data.time[0]).total_seconds() # type: ignore
dt = data.timestep
n_time_steps = len(data.time)

geometry = data.geometry
Expand Down Expand Up @@ -485,7 +482,13 @@ def read(
data_list = [np.squeeze(d, axis=-1) for d in data_list]

return Dataset(
data_list, time, items, geometry=geometry, dims=dims, validate=False
data_list,
time,
items,
geometry=geometry,
dims=dims,
validate=False,
dt=self.timestep,
)

def _parse_geometry_sel(self, area, x, y):
Expand Down
9 changes: 8 additions & 1 deletion mikeio/dfsu/_layered.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,17 @@ def read(
zn=data_list[0],
dims=dims,
validate=False,
dt=self.timestep,
)
else:
return Dataset(
data_list, time, items, geometry=geometry, dims=dims, validate=False
data_list,
time,
items,
geometry=geometry,
dims=dims,
validate=False,
dt=self.timestep,
)


Expand Down
30 changes: 30 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,3 +1609,33 @@ def test_interp_na():
def test_plot_scatter():
    """Smoke test: a scatter plot of two items from a single time step runs."""
    first_step = mikeio.read("tests/testdata/oresund_sigma_z.dfsu", time=0)
    plotter = first_step.plot
    plotter.scatter(x="Salinity", y="Temperature", title="S-vs-T")


def test_select_single_timestep_preserves_dt():
    """Selecting one time step keeps the time step reported before selection.

    Checked both on the resulting dataset and on its first item.
    """
    full = mikeio.read("tests/testdata/tide1.dfs1")
    assert full.timestep == pytest.approx(1800.0)

    # A single remaining time instance has no interval to measure, so the
    # reported time step must come from the value carried along by isel().
    last_only = full.isel(time=-1)
    assert last_only.timestep == pytest.approx(1800.0)

    first_item = last_only[0]
    assert first_item.timestep == pytest.approx(1800.0)


def test_select_multiple_spaced_timesteps_uses_proper_dt():
    """Selecting every other time step reports the new, wider spacing.

    With more than one equidistant time instance left, ``timestep`` must be
    derived from the selected time axis (2 x 1800 s) rather than from the
    time step reported before selection.

    The previously requested ``tmp_path`` fixture was never used (nothing is
    written to disk here), so it has been dropped.
    """
    ds = mikeio.read("tests/testdata/tide1.dfs1")
    assert ds.timestep == pytest.approx(1800.0)

    every_other = ds.isel(time=[0, 2, 4])
    assert every_other.timestep == pytest.approx(3600.0)


def test_read_write_single_timestep_preserves_dt(tmp_path):
    """Round trip: a single-time-step dataset written to dfsu keeps its dt.

    Reads one time step from a file, writes it to a new file under
    ``tmp_path`` and checks that the reopened file reports the same
    time step as the source file.
    """
    source = mikeio.open("tests/testdata/oresund_sigma_z.dfsu")
    assert source.timestep == pytest.approx(10800.0)

    # Only the first time instance is read, so the dataset itself has no
    # interval from which a time step could be computed.
    single = source.read(time=[0])
    assert single.timestep == pytest.approx(source.timestep)

    target = tmp_path / "single.dfsu"
    single.to_dfs(target)

    reopened = mikeio.open(target)
    assert reopened.timestep == pytest.approx(10800.0)
Loading