
Commit

update
b8raoult committed Mar 1, 2024
1 parent 8945a41 commit 60d2bf6
Showing 3 changed files with 42 additions and 204 deletions.
199 changes: 42 additions & 157 deletions ecml_tools/data.py
@@ -73,9 +73,7 @@ def _make_slice_or_index_from_list_or_tuple(indices):
 
     step = indices[1] - indices[0]
 
-    if step > 0 and all(
-        indices[i] - indices[i - 1] == step for i in range(1, len(indices))
-    ):
+    if step > 0 and all(indices[i] - indices[i - 1] == step for i in range(1, len(indices))):
         return slice(indices[0], indices[-1] + step, step)
 
     return indices
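
For reference, the helper reformatted above collapses an evenly spaced list of indices into an equivalent slice and leaves irregular lists untouched:

    # Illustrative use of the module-level helper shown in this hunk.
    from ecml_tools.data import _make_slice_or_index_from_list_or_tuple

    assert _make_slice_or_index_from_list_or_tuple([3, 5, 7]) == slice(3, 9, 2)  # regular -> slice
    assert _make_slice_or_index_from_list_or_tuple([3, 7, 5]) == [3, 7, 5]  # irregular -> unchanged
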
@@ -85,10 +83,6 @@ class MissingDate(Exception):
     pass
 
 
-class InvalidSlice(Exception):
-    pass
-
-
 class Dataset:
     arguments = {}
 
@@ -130,12 +124,6 @@ def _subset(self, **kwargs):
             statistics = kwargs.pop("statistics")
             return Statistics(self, statistics)._subset(**kwargs)
 
-        if "with_valid_ranges" in kwargs:
-            if not self.missing:
-                return self
-            width = kwargs.pop("with_valid_ranges")
-            return ValidRanges(self, width)._subset(**kwargs)
-
        raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
 
     def _frequency_to_indices(self, frequency):
@@ -239,30 +227,11 @@ def __repr__(self):
     @debug_indexing
     @expand_list_indexing
     def _get_tuple(self, n):
-        raise NotImplementedError(
-            f"Tuple not supported: {n} (class {self.__class__.__name__})"
-        )
-
-    def valid_ranges(self, width):
-        # TODO: optimize, by implementing in subclasses
-        result = []
-        n = 0
-        missing = self.missing
-        while n < len(self):
-            if n in missing:
-                n += 1
-                continue
-
-            m = n
-            while m < len(self) and m not in missing:
-                m += 1
-
-            if m - n >= width:
-                result.append((n, m))
-
-            n = m
-
-        return result
+        raise NotImplementedError(f"Tuple not supported: {n} (class {self.__class__.__name__})")
+
+    @property
+    def grids(self):
+        return (self.shape[-1],)
 
 
 class Source:
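
The new grids property on the base Dataset reports the number of points contributed by each constituent grid; for a single-grid dataset it is a one-element tuple. A minimal sketch (the dataset name is hypothetical):

    from ecml_tools.data import open_dataset

    ds = open_dataset("some-zarr-dataset")  # hypothetical name
    assert ds.grids == (ds.shape[-1],)  # one grid holding all the points
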
@@ -280,9 +249,7 @@ def __repr__(self):
             p = s
             s = s.source
 
-        return (
-            f"{self.dataset}[{self.index}, {self.dataset.variables[self.index]}] ({p})"
-        )
+        return f"{self.dataset}[{self.index}, {self.dataset.variables[self.index]}] ({p})"
 
     def target(self):
         p = s = self.source
@@ -539,9 +506,7 @@ def __init__(self, path):
 
         missing_dates = self.z.attrs.get("missing_dates", [])
         missing_dates = [np.datetime64(x) for x in missing_dates]
-        self.missing_to_dates = {
-            i: d for i, d in enumerate(self.dates) if d in missing_dates
-        }
+        self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates}
         self.missing = set(self.missing_to_dates)
 
     def mutate(self):
@@ -640,6 +605,10 @@ def dtype(self):
     def missing(self):
         return self.forward.missing
 
+    @property
+    def grids(self):
+        return self.forward.grids
+
     def metadata_specific(self, **kwargs):
         return super().metadata_specific(
             forward=self.forward.metadata_specific(),
@@ -660,47 +629,33 @@ def __init__(self, datasets):
 
     def check_same_resolution(self, d1, d2):
         if d1.resolution != d2.resolution:
-            raise ValueError(
-                f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})")
 
     def check_same_frequency(self, d1, d2):
         if d1.frequency != d2.frequency:
-            raise ValueError(
-                f"Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})")
 
     def check_same_grid(self, d1, d2):
-        if (d1.latitudes != d2.latitudes).any() or (
-            d1.longitudes != d2.longitudes
-        ).any():
+        if (d1.latitudes != d2.latitudes).any() or (d1.longitudes != d2.longitudes).any():
             raise ValueError(f"Incompatible grid ({d1} {d2})")
 
     def check_same_shape(self, d1, d2):
         if d1.shape[1:] != d2.shape[1:]:
-            raise ValueError(
-                f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})")
 
         if d1.variables != d2.variables:
-            raise ValueError(
-                f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})")
 
     def check_same_sub_shapes(self, d1, d2, drop_axis):
         shape1 = d1.sub_shape(drop_axis)
         shape2 = d2.sub_shape(drop_axis)
 
         if shape1 != shape2:
-            raise ValueError(
-                f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})")
 
     def check_same_variables(self, d1, d2):
         if d1.variables != d2.variables:
-            raise ValueError(
-                f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})")
 
     def check_same_lengths(self, d1, d2):
         if d1._len != d2._len:
@@ -710,14 +665,10 @@ def check_same_dates(self, d1, d2):
         self.check_same_frequency(d1, d2)
 
         if d1.dates[0] != d2.dates[0]:
-            raise ValueError(
-                f"Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})")
 
         if d1.dates[-1] != d2.dates[-1]:
-            raise ValueError(
-                f"Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})"
-            )
+            raise ValueError(f"Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})")
 
     def check_compatibility(self, d1, d2):
         # These are the default checks
@@ -767,11 +718,7 @@ def _get_tuple(self, index):
         lengths = [d.shape[0] for d in self.datasets]
         slices = length_to_slices(index[0], lengths)
         # print("slies", slices)
-        result = [
-            d[update_tuple(index, 0, i)[0]]
-            for (d, i) in zip(self.datasets, slices)
-            if i is not None
-        ]
+        result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
         result = np.concatenate(result, axis=0)
         return apply_index_to_slices_changes(result, changes)
 
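For context, length_to_slices (a helper defined elsewhere in this module) splits a request over the concatenated axis into one piece per dataset, and the comprehension above skips the members that return None. A rough sketch of the semantics for positive steps, not the module's actual implementation:

    def length_to_slices_sketch(index, lengths):
        # Map a slice over the concatenated axis to per-dataset slices;
        # None marks datasets the request does not touch.
        start, stop, step = index.indices(sum(lengths))
        result, offset = [], 0
        for n in lengths:
            begin, end = max(start, offset), min(stop, offset + n)
            shift = (begin - start) % step
            if shift:
                begin += step - shift  # align to the stride of the request
            result.append(slice(begin - offset, end - offset, step) if begin < end else None)
            offset += n
        return result

    # slice(2, 9) over lengths [4, 4, 4] -> [slice(2, 4, 1), slice(0, 4, 1), slice(0, 1, 1)]
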
@@ -853,11 +800,7 @@ def _get_tuple(self, index):
         index, changes = index_to_slices(index, self.shape)
         lengths = [d.shape[self.axis] for d in self.datasets]
         slices = length_to_slices(index[self.axis], lengths)
-        result = [
-            d[update_tuple(index, self.axis, i)[0]]
-            for (d, i) in zip(self.datasets, slices)
-            if i is not None
-        ]
+        result = [d[update_tuple(index, self.axis, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
         result = np.concatenate(result, axis=self.axis)
         return apply_index_to_slices_changes(result, changes)
 
@@ -901,6 +844,13 @@ def latitudes(self):
     def longitudes(self):
         return np.concatenate([d.longitudes for d in self.datasets])
 
+    @property
+    def grids(self):
+        result = []
+        for d in self.datasets:
+            result.extend(d.grids)
+        return tuple(result)
+
 
 class CutoutGrids(Grids):
     def __init__(self, datasets, axis):
@@ -953,9 +903,7 @@ def __getitem__(self, index):
     @debug_indexing
     @expand_list_indexing
     def _get_tuple(self, index):
-        assert index[self.axis] == slice(
-            None
-        ), "No support for selecting a subset of the 1D values"
+        assert index[self.axis] == slice(None), "No support for selecting a subset of the 1D values"
         index, changes = index_to_slices(index, self.shape)
 
         # In case index_to_slices has changed the last slice
@@ -970,6 +918,14 @@ def _get_tuple(self, index):
 
         return apply_index_to_slices_changes(result, changes)
 
+    @property
+    def grids(self):
+        for d in self.datasets:
+            if len(d.grids) > 1:
+                raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs")
+        shape = self.lam.shape
+        return (shape[-1], self.shape[-1] - shape[-1])
+
 
 class Join(Combined):
     """Join the datasets along the variables axis."""
@@ -1062,8 +1018,7 @@ def name_to_index(self):
     @property
     def statistics(self):
         return {
-            k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0)
-            for k in self.datasets[0].statistics
+            k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0) for k in self.datasets[0].statistics
         }
 
     def source(self, index):
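
Each statistics entry (for example the per-variable mean) is concatenated across the combined datasets along the variables axis; illustratively:

    import numpy as np

    stats_a = {"mean": np.array([1.0, 2.0])}  # two variables
    stats_b = {"mean": np.array([3.0])}  # one variable
    joined = {k: np.concatenate([stats_a[k], stats_b[k]], axis=0) for k in stats_a}
    assert joined["mean"].tolist() == [1.0, 2.0, 3.0]
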
@@ -1160,71 +1115,6 @@ def missing(self):
         return {self.indices[i] for i in self.dataset.missing if i in self.indices}
 
 
-class ValidRanges(Subset):
-    def __init__(self, dataset, width):
-        self.width = width
-        self.ranges = dataset.valid_ranges(width)
-
-        super().__init__(dataset, [i for r in self.ranges for i in range(*r)])
-
-        assert not self.missing, self.missing
-
-    @debug_indexing
-    def _get_slice(self, s):
-        self._check_slice(s)
-        return super()._get_slice(s)
-
-    @debug_indexing
-    @expand_list_indexing
-    def _get_tuple(self, n):
-        if isinstance(n[0], int):
-            return super()._get_tuple(n)
-
-        if isinstance(n[0], slice):
-            self._check_slice(n[0])
-            return super()._get_tuple(n)
-
-        if isinstance(n[0], (list, tuple)):
-            self._check_list(n[0])
-            return super()._get_tuple(n)
-
-        raise TypeError(f"Unsupported index {n} {type(n)}")
-
-    def _check_slice(self, s):
-        """Check that once mapped in the new index, the slice is regular."""
-
-        indices = [self.indices[i] for i in range(*s.indices(self._len))]
-        self._check_indices(s, indices)
-
-    def _check_list(self, s):
-        indices = [self.indices[i] for i in s]
-        self._check_indices(s, indices)
-
-    def _check_indices(self, s, indices):
-        indices = _make_slice_or_index_from_list_or_tuple(indices)
-        if isinstance(indices, slice):
-            # The slice is regular in the new index, so we do not have
-            # missing dates
-            return
-
-        dates = self.dataset.dates[indices]
-        delta = [dates[i] - dates[0] for i in range(1, len(dates))]
-        delta = [int(x.astype(object).total_seconds() / 3600) for x in delta]
-        delta = ",".join(["T"] + [f"T{x:+d}" for x in delta])
-
-        if isinstance(s, slice):
-            idx = [str(x) if x is not None else "" for x in (s.start, s.stop, s.step)]
-            idx = ":".join(idx)
-            if s.step is None:
-                idx = idx[:-1]
-        else:
-            idx = f"({','.join(str(x) for x in s)})"
-
-        raise InvalidSlice(
-            f"Index [{idx}] does not cover a regular range: [{delta}] (width={self.width})"
-        )
-
-
 class Select(Forwards):
     """Select a subset of the variables."""
 
@@ -1450,9 +1340,7 @@ def _concat_or_join(datasets, kwargs):
         for j in range(i + 1, len(ranges)):
             s = ranges[j]
             if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
-                raise ValueError(
-                    f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})"
-                )
+                raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
 
     # For now we should have the datasets in order with no gaps
 
@@ -1463,8 +1351,7 @@ def _concat_or_join(datasets, kwargs):
         s = ranges[i + 1]
         if r[1] + datetime.timedelta(hours=frequency) != s[0]:
             raise ValueError(
-                "Datasets must be sorted by dates, with no gaps: "
-                f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
+                "Datasets must be sorted by dates, with no gaps: " f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
             )
 
     return Concat(datasets), kwargs
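
The check above requires consecutive datasets to abut exactly: each one must start one frequency step after the previous one ends. For example, with 6-hourly data:

    import datetime

    end_of_first = datetime.datetime(2020, 6, 30, 18)  # last date of datasets[i]
    start_of_second = datetime.datetime(2020, 7, 1, 0)  # first date of datasets[i + 1]
    assert end_of_first + datetime.timedelta(hours=6) == start_of_second
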
@@ -1570,9 +1457,7 @@ def _open_dataset(*args, zarr_root, **kwargs):
             "cutout": CutoutGrids,
         }
         if mode not in KLASSES:
-            raise ValueError(
-                f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}"
-            )
+            raise ValueError(f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}")
 
         datasets = [_open(e, zarr_root) for e in grids]
         datasets, kwargs = _auto_adjust(datasets, kwargs)
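
An unrecognised mode therefore fails fast. Assuming the grids / mode keywords reach this function as in the sketch above:

    try:
        open_dataset(grids=["lam-dataset", "global-dataset"], mode="overlay")  # "overlay" is not a registered mode
    except ValueError as e:
        print(e)  # Unknown grids mode: overlay, values are [..., 'cutout']
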
25 changes: 0 additions & 25 deletions scripts/get-grids.py

This file was deleted.

22 changes: 0 additions & 22 deletions (third changed file; diff not loaded)

