Skip to content

Commit

Permalink
Merge pull request #594 from DHI/timeselector
Browse files Browse the repository at this point in the history
Extract class time step selection
  • Loading branch information
ecomodeller authored Oct 11, 2023
2 parents c80f383 + b1eb162 commit 47882c2
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 111 deletions.
69 changes: 69 additions & 0 deletions mikeio/_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations
from datetime import datetime
from dataclasses import dataclass
from typing import List, Iterable, Optional

import pandas as pd


@dataclass
class DateTimeSelector:
"""Helper class for selecting time steps from a pandas DatetimeIndex"""

index: pd.DatetimeIndex

def isel(
self,
x: Optional[
int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice
] = None,
) -> List[int]:
"""Select time steps from a pandas DatetimeIndex
Parameters
----------
x : int, Iterable[int], str, datetime, pd.DatetimeIndex, slice
Time steps to select, negative indices are supported
Returns
-------
List[int]
List of indices in the range (0, len(index)
Examples
--------
>>> idx = pd.date_range("2000-01-01", periods=4, freq="D")
>>> dts = DateTimeSelector(idx)
>>> dts.isel(None)
[0, 1, 2, 3]
>>> dts.isel(0)
[0]
>>> dts.isel(-1)
[3]
"""

indices = list(range(len(self.index)))

if x is None:
return indices

if isinstance(x, int):
return [indices[x]]

if isinstance(x, (datetime, str)):
loc = self.index.get_loc(x)
if isinstance(loc, int):
return [loc]
elif isinstance(loc, slice):
return list(range(loc.start, loc.stop))

if isinstance(x, slice):
if isinstance(x.start, int) or isinstance(x.stop, int):
return indices[x]
else:
s = self.index.slice_indexer(x.start, x.stop)
return list(range(s.start, s.stop))

if isinstance(x, Iterable):
return [self.isel(t)[0] for t in x]

return indices
47 changes: 11 additions & 36 deletions mikeio/dataset/_data_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
import re
from datetime import datetime
from typing import Iterable, Sequence, Sized, Tuple
from typing import Iterable, Sequence, Sized, Tuple, Union, List

import numpy as np
import pandas as pd

from .._time import DateTimeSelector


def _to_safe_name(name: str) -> str:
tmp = re.sub("[^0-9a-zA-Z]", "_", name)
Expand All @@ -18,45 +19,19 @@ def _n_selected_timesteps(x: Sized, k: slice | Sized) -> int:
return len(k)


def _get_time_idx_list(time: pd.DatetimeIndex, steps):
def _get_time_idx_list(time: pd.DatetimeIndex, steps) -> Union [List[int], slice]:
"""Find list of idx in DatetimeIndex"""

if isinstance(steps, str):
parts = steps.split(",")
if len(parts) == 1:
parts.append(parts[0]) # end=start

if parts[0] == "":
steps = slice(parts[1]) # stop only
elif parts[1] == "":
steps = slice(parts[0], None) # start only
else:
steps = slice(parts[0], parts[1])
# indexing with a slice needs to be handled differently, since slicing returns a view

if isinstance(steps, (list, tuple)) and isinstance(
steps[0], (str, datetime, np.datetime64, pd.Timestamp)
):
steps = pd.DatetimeIndex(steps)
if isinstance(steps, pd.DatetimeIndex):
return time.get_indexer(steps)
if isinstance(steps, (str, datetime, np.datetime64, pd.Timestamp)):
steps = slice(steps, steps)
if isinstance(steps, slice):
try:
s = time.slice_indexer(
steps.start,
steps.stop,
)
steps = list(range(s.start, s.stop))
except TypeError:
pass # TODO this seems fishy!
# steps = list(range(*steps.indices(len(time))))
elif isinstance(steps, int):
steps = [steps]
# TODO what is the return type of this function
return steps
if isinstance(steps.start, int) and isinstance(steps.stop, int):
return steps

dts = DateTimeSelector(time)
return dts.isel(steps)

# TODO this only used by DataArray, so consider to move it there
class DataUtilsMixin:
"""DataArray Utils"""

Expand Down Expand Up @@ -107,7 +82,7 @@ def _set_by_boolean_mask(data: np.ndarray, mask: np.ndarray, value) -> None:
def _parse_time(time) -> pd.DatetimeIndex:
"""Allow anything that we can create a DatetimeIndex from"""
if time is None:
time = [pd.Timestamp(2018, 1, 1)]
time = [pd.Timestamp(2018, 1, 1)] # TODO is this the correct epoch?
if isinstance(time, str) or (not isinstance(time, Iterable)):
time = [time]

Expand Down
121 changes: 50 additions & 71 deletions mikeio/dfs/_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Iterable, List, Optional, Tuple, Sequence
from typing import List, Optional, Tuple, Sequence
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
Expand All @@ -23,6 +23,7 @@
from ..eum import EUMType, EUMUnit, ItemInfo, ItemInfoList, TimeStepUnit
from ..exceptions import DataDimensionMismatch, ItemsError
from ..spatial import GeometryUndefined
from .._time import DateTimeSelector


def _read_item_time_step(
Expand Down Expand Up @@ -77,10 +78,10 @@ def _valid_item_numbers(
n_items_file = len(dfsItemInfo) - start_idx
if items is None:
return list(range(n_items_file))

# Handling scalar and sequences is a bit tricky
item_numbers : List[int] = []

item_numbers: List[int] = []

# check if items is a scalar (int or str)
if isinstance(items, (int, str)):
Expand All @@ -89,13 +90,13 @@ def _valid_item_numbers(
dfsItemInfo=dfsItemInfo, search=items, start_idx=start_idx
)
elif isinstance(items, str):
item_number = _item_numbers_by_name(dfsItemInfo, [items], ignore_first)[0]
item_number = _item_numbers_by_name(dfsItemInfo, [items], ignore_first)[0]
return [item_number]
elif isinstance(items, int):
if (items < 0) or (items >= n_items_file):
raise ItemsError(n_items_file)
return [items]

assert isinstance(items, Sequence)
for item in items:
if isinstance(item, str):
Expand All @@ -116,78 +117,48 @@ def _valid_item_numbers(

def _valid_timesteps(dfsFileInfo: DfsFileInfo, time_steps) -> Tuple[bool, List[int]]:

time_axis = dfsFileInfo.TimeAxis

single_time_selected = False
if isinstance(time_steps, int) and np.isscalar(time_steps):
if isinstance(time_steps, (int, datetime)):
single_time_selected = True

n_steps_file = dfsFileInfo.TimeAxis.NumberOfTimeSteps
if time_steps is None:
return single_time_selected, list(range(n_steps_file))
nt = time_axis.NumberOfTimeSteps

if isinstance(time_steps, int):
time_steps = [time_steps]
if time_axis.TimeAxisType != TimeAxisType.CalendarEquidistant:
# TODO is this the proper epoch, should this magic number be somewhere else?
start_time_file = datetime(1970, 1, 1)
else:
start_time_file = time_axis.StartDateTime

if isinstance(time_steps, str):
parts = time_steps.split(",")
if len(parts) == 1:
parts.append(parts[0]) # end=start

if parts[0] == "":
time_steps = slice(parts[1]) # stop only
elif parts[1] == "":
time_steps = slice(parts[0], None) # start only
else:
time_steps = slice(parts[0], parts[1])
if time_axis.TimeAxisType in (
TimeAxisType.CalendarEquidistant,
TimeAxisType.TimeEquidistant,
):
time_step_file = time_axis.TimeStep
freq = pd.Timedelta(seconds=time_step_file)
time = pd.date_range(start_time_file, periods=nt, freq=freq)
elif time_axis.TimeAxisType == TimeAxisType.CalendarNonEquidistant:
idx = list(range(nt))

if isinstance(time_steps, (slice, pd.Timestamp, datetime, pd.DatetimeIndex)):
if dfsFileInfo.TimeAxis.TimeAxisType != TimeAxisType.CalendarEquidistant:
# TODO: handle non-equidistant calendar
raise ValueError(
"Only equidistant calendar files are supported for this type of time_step argument"
)
if isinstance(time_steps, int):
return True, [idx[time_steps]]
return single_time_selected, idx

start_time_file = dfsFileInfo.TimeAxis.StartDateTime
time_step_file = dfsFileInfo.TimeAxis.TimeStep
freq = pd.Timedelta(seconds=time_step_file)
time = pd.date_range(start_time_file, periods=n_steps_file, freq=freq)
dts = DateTimeSelector(time)

if isinstance(time_steps, slice):
if isinstance(time_steps.start, int) or isinstance(time_steps.stop, int):
time_steps = list(range(*time_steps.indices(n_steps_file)))
else:
s = time.slice_indexer(time_steps.start, time_steps.stop)
time_steps = list(range(s.start, s.stop))
elif isinstance(time_steps, Sequence) and isinstance(time_steps[0], int):
time_steps = np.array(time_steps)
time_steps[time_steps < 0] = n_steps_file + time_steps[time_steps < 0]
time_steps = list(time_steps)

if max(time_steps) > (n_steps_file - 1):
raise IndexError(f"Timestep cannot be larger than {n_steps_file}")
if min(time_steps) < 0:
raise IndexError(f"Timestep cannot be less than {-n_steps_file}")
elif isinstance(time_steps, Iterable):
steps = []
for t in time_steps:
_, step = _valid_timesteps(dfsFileInfo, t)
steps.append(step[0])
single_time_selected = len(steps) == 1
time_steps = steps

elif isinstance(time_steps, (pd.Timestamp, datetime)):
s = time.slice_indexer(time_steps, time_steps)
time_steps = list(range(s.start, s.stop))
#elif isinstance(time_steps, pd.DatetimeIndex):
# time_steps = list(time.get_indexer(time_steps))
idx = dts.isel(time_steps)

else:
raise TypeError(f"Indexing is not possible with {type(time_steps)}")
if len(time_steps) == 1:
single_time_selected = True
return single_time_selected, time_steps
if isinstance(time_steps, str):
if len(idx) == 1:
single_time_selected = True

return single_time_selected, idx


def _item_numbers_by_name(dfsItemInfo, item_names: List[str], ignore_first: bool=False) -> List[int]:
def _item_numbers_by_name(
dfsItemInfo, item_names: List[str], ignore_first: bool = False
) -> List[int]:
"""Utility function to find item numbers
Parameters
Expand Down Expand Up @@ -243,7 +214,9 @@ def _get_item_info(
item_numbers = list(range(len(dfsItemInfo) - first_idx))

item_numbers = [i + first_idx for i in item_numbers]
items = [ItemInfo.from_mikecore_dynamic_item_info(dfsItemInfo[i]) for i in item_numbers]
items = [
ItemInfo.from_mikecore_dynamic_item_info(dfsItemInfo[i]) for i in item_numbers
]
return ItemInfoList(items)


Expand Down Expand Up @@ -387,7 +360,9 @@ def _read_header(self):
}:
self._start_time = dfs.FileInfo.TimeAxis.StartDateTime
else: # relative time axis
self._start_time = datetime(1970, 1, 1)
self._start_time = datetime(
1970, 1, 1
) # TODO is this the proper epoch, should this magic number be somewhere else?
if hasattr(dfs.FileInfo.TimeAxis, "TimeStep"):
self._timestep_in_seconds = (
dfs.FileInfo.TimeAxis.TimeStep
Expand Down Expand Up @@ -502,7 +477,9 @@ def append(self, data: Dataset) -> None:
darray = d.reshape(d.size, 1)[:, 0]

if self._ndim == 3:
raise NotImplementedError("Append is not yet available for 3D files")
raise NotImplementedError(
"Append is not yet available for 3D files"
)

if self._is_equidistant:
self._dfs.WriteItemTimeStepNext(0, darray.astype(np.float32))
Expand Down Expand Up @@ -717,7 +694,9 @@ def time(self) -> pd.DatetimeIndex:
# this will fail if the TimeAxisType is not calendar and equidistant, but that is ok
if not self._is_equidistant:
raise NotImplementedError("Not implemented for non-equidistant files")
return pd.date_range(start=self.start_time, periods=self.n_timesteps, freq=f"{self.timestep}S")
return pd.date_range(
start=self.start_time, periods=self.n_timesteps, freq=f"{self.timestep}S"
)

@property
def projection_string(self):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,13 @@ def test_read_dfs_time_slice_str():
assert dsr.shape == dsgetitem.shape


def test_read_dfs_time_selection_str_comma():
def test_read_dfs_time_selection_str_slice():

extensions = ["dfs0", "dfs2", "dfs1", "dfs0"]
for ext in extensions:
filename = f"tests/testdata/consistency/oresundHD.{ext}"
time = "2018-03-08,2018-03-10"

time = slice("2018-03-08","2018-03-10")
ds = mikeio.read(filename=filename)
dssel = ds.sel(time=time)
assert dssel.n_timesteps == 3
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ def test_time_selection():

assert das_t.shape == (24,)

with pytest.raises(IndexError):
with pytest.raises(KeyError):
# not in time
ds.sel(time="1997-09-15 00:00")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dfsu_layered.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_read_column_interp_time_and_select_time():
salinity_st = da.sel(time="1997-09-15 23:00") # single time-step
assert salinity_st.n_timesteps == 1

with pytest.raises(IndexError):
with pytest.raises(KeyError):
# not in time
da.sel(time="1997-09-15 00:00")

Expand Down
Loading

0 comments on commit 47882c2

Please sign in to comment.