diff --git a/ecml_tools/data.py b/ecml_tools/data.py index 371b811..af37184 100644 --- a/ecml_tools/data.py +++ b/ecml_tools/data.py @@ -73,9 +73,7 @@ def _make_slice_or_index_from_list_or_tuple(indices): step = indices[1] - indices[0] - if step > 0 and all( - indices[i] - indices[i - 1] == step for i in range(1, len(indices)) - ): + if step > 0 and all(indices[i] - indices[i - 1] == step for i in range(1, len(indices))): return slice(indices[0], indices[-1] + step, step) return indices @@ -85,10 +83,6 @@ class MissingDate(Exception): pass -class InvalidSlice(Exception): - pass - - class Dataset: arguments = {} @@ -130,12 +124,6 @@ def _subset(self, **kwargs): statistics = kwargs.pop("statistics") return Statistics(self, statistics)._subset(**kwargs) - if "with_valid_ranges" in kwargs: - if not self.missing: - return self - width = kwargs.pop("with_valid_ranges") - return ValidRanges(self, width)._subset(**kwargs) - raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs)) def _frequency_to_indices(self, frequency): @@ -239,30 +227,11 @@ def __repr__(self): @debug_indexing @expand_list_indexing def _get_tuple(self, n): - raise NotImplementedError( - f"Tuple not supported: {n} (class {self.__class__.__name__})" - ) - - def valid_ranges(self, width): - # TODO: optimize, by implementing in subclasses - result = [] - n = 0 - missing = self.missing - while n < len(self): - if n in missing: - n += 1 - continue - - m = n - while m < len(self) and m not in missing: - m += 1 - - if m - n >= width: - result.append((n, m)) - - n = m + raise NotImplementedError(f"Tuple not supported: {n} (class {self.__class__.__name__})") - return result + @property + def grids(self): + return (self.shape[-1],) class Source: @@ -280,9 +249,7 @@ def __repr__(self): p = s s = s.source - return ( - f"{self.dataset}[{self.index}, {self.dataset.variables[self.index]}] ({p})" - ) + return f"{self.dataset}[{self.index}, {self.dataset.variables[self.index]}] ({p})" def target(self): p = s = self.source @@ -539,9 +506,7 @@ def __init__(self, path): missing_dates = self.z.attrs.get("missing_dates", []) missing_dates = [np.datetime64(x) for x in missing_dates] - self.missing_to_dates = { - i: d for i, d in enumerate(self.dates) if d in missing_dates - } + self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates} self.missing = set(self.missing_to_dates) def mutate(self): @@ -640,6 +605,10 @@ def dtype(self): def missing(self): return self.forward.missing + @property + def grids(self): + return self.forward.grids + def metadata_specific(self, **kwargs): return super().metadata_specific( forward=self.forward.metadata_specific(), @@ -660,47 +629,33 @@ def __init__(self, datasets): def check_same_resolution(self, d1, d2): if d1.resolution != d2.resolution: - raise ValueError( - f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})" - ) + raise ValueError(f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})") def check_same_frequency(self, d1, d2): if d1.frequency != d2.frequency: - raise ValueError( - f"Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})" - ) + raise ValueError(f"Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})") def check_same_grid(self, d1, d2): - if (d1.latitudes != d2.latitudes).any() or ( - d1.longitudes != d2.longitudes - ).any(): + if (d1.latitudes != d2.latitudes).any() or (d1.longitudes != d2.longitudes).any(): raise ValueError(f"Incompatible grid ({d1} {d2})") def check_same_shape(self, d1, d2): if d1.shape[1:] != d2.shape[1:]: - raise ValueError( - f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})" - ) + raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") if d1.variables != d2.variables: - raise ValueError( - f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})" - ) + raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})") def check_same_sub_shapes(self, d1, d2, drop_axis): shape1 = d1.sub_shape(drop_axis) shape2 = d2.sub_shape(drop_axis) if shape1 != shape2: - raise ValueError( - f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})" - ) + raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") def check_same_variables(self, d1, d2): if d1.variables != d2.variables: - raise ValueError( - f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})" - ) + raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})") def check_same_lengths(self, d1, d2): if d1._len != d2._len: @@ -710,14 +665,10 @@ def check_same_dates(self, d1, d2): self.check_same_frequency(d1, d2) if d1.dates[0] != d2.dates[0]: - raise ValueError( - f"Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})" - ) + raise ValueError(f"Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})") if d1.dates[-1] != d2.dates[-1]: - raise ValueError( - f"Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})" - ) + raise ValueError(f"Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})") def check_compatibility(self, d1, d2): # These are the default checks @@ -767,11 +718,7 @@ def _get_tuple(self, index): lengths = [d.shape[0] for d in self.datasets] slices = length_to_slices(index[0], lengths) # print("slies", slices) - result = [ - d[update_tuple(index, 0, i)[0]] - for (d, i) in zip(self.datasets, slices) - if i is not None - ] + result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None] result = np.concatenate(result, axis=0) return apply_index_to_slices_changes(result, changes) @@ -853,11 +800,7 @@ def _get_tuple(self, index): index, changes = index_to_slices(index, self.shape) lengths = [d.shape[self.axis] for d in self.datasets] slices = length_to_slices(index[self.axis], lengths) - result = [ - d[update_tuple(index, self.axis, i)[0]] - for (d, i) in zip(self.datasets, slices) - if i is not None - ] + result = [d[update_tuple(index, self.axis, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None] result = np.concatenate(result, axis=self.axis) return apply_index_to_slices_changes(result, changes) @@ -901,6 +844,13 @@ def latitudes(self): def longitudes(self): return np.concatenate([d.longitudes for d in self.datasets]) + @property + def grids(self): + result = [] + for d in self.datasets: + result.extend(d.grids) + return tuple(result) + class CutoutGrids(Grids): def __init__(self, datasets, axis): @@ -953,9 +903,7 @@ def __getitem__(self, index): @debug_indexing @expand_list_indexing def _get_tuple(self, index): - assert index[self.axis] == slice( - None - ), "No support for selecting a subset of the 1D values" + assert index[self.axis] == slice(None), "No support for selecting a subset of the 1D values" index, changes = index_to_slices(index, self.shape) # In case index_to_slices has changed the last slice @@ -970,6 +918,14 @@ def _get_tuple(self, index): return apply_index_to_slices_changes(result, changes) + @property + def grids(self): + for d in self.datasets: + if len(d.grids) > 1: + raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs") + shape = self.lam.shape + return (shape[-1], self.shape[-1] - shape[-1]) + class Join(Combined): """Join the datasets along the variables axis.""" @@ -1062,8 +1018,7 @@ def name_to_index(self): @property def statistics(self): return { - k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0) - for k in self.datasets[0].statistics + k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0) for k in self.datasets[0].statistics } def source(self, index): @@ -1160,71 +1115,6 @@ def missing(self): return {self.indices[i] for i in self.dataset.missing if i in self.indices} -class ValidRanges(Subset): - def __init__(self, dataset, width): - self.width = width - self.ranges = dataset.valid_ranges(width) - - super().__init__(dataset, [i for r in self.ranges for i in range(*r)]) - - assert not self.missing, self.missing - - @debug_indexing - def _get_slice(self, s): - self._check_slice(s) - return super()._get_slice(s) - - @debug_indexing - @expand_list_indexing - def _get_tuple(self, n): - if isinstance(n[0], int): - return super()._get_tuple(n) - - if isinstance(n[0], slice): - self._check_slice(n[0]) - return super()._get_tuple(n) - - if isinstance(n[0], (list, tuple)): - self._check_list(n[0]) - return super()._get_tuple(n) - - raise TypeError(f"Unsupported index {n} {type(n)}") - - def _check_slice(self, s): - """Check that once mapped in the new index, the slice is regular.""" - - indices = [self.indices[i] for i in range(*s.indices(self._len))] - self._check_indices(s, indices) - - def _check_list(self, s): - indices = [self.indices[i] for i in s] - self._check_indices(s, indices) - - def _check_indices(self, s, indices): - indices = _make_slice_or_index_from_list_or_tuple(indices) - if isinstance(indices, slice): - # The slice is regular in the new index, so we do not have - # missing dates - return - - dates = self.dataset.dates[indices] - delta = [dates[i] - dates[0] for i in range(1, len(dates))] - delta = [int(x.astype(object).total_seconds() / 3600) for x in delta] - delta = ",".join(["T"] + [f"T{x:+d}" for x in delta]) - - if isinstance(s, slice): - idx = [str(x) if x is not None else "" for x in (s.start, s.stop, s.step)] - idx = ":".join(idx) - if s.step is None: - idx = idx[:-1] - else: - idx = f"({','.join(str(x) for x in s)})" - - raise InvalidSlice( - f"Index [{idx}] does not cover a regular range: [{delta}] (width={self.width})" - ) - - class Select(Forwards): """Select a subset of the variables.""" @@ -1450,9 +1340,7 @@ def _concat_or_join(datasets, kwargs): for j in range(i + 1, len(ranges)): s = ranges[j] if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]: - raise ValueError( - f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})" - ) + raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})") # For now we should have the datasets in order with no gaps @@ -1463,8 +1351,7 @@ def _concat_or_join(datasets, kwargs): s = ranges[i + 1] if r[1] + datetime.timedelta(hours=frequency) != s[0]: raise ValueError( - "Datasets must be sorted by dates, with no gaps: " - f"{r} and {s} ({datasets[i]} {datasets[i+1]})" + "Datasets must be sorted by dates, with no gaps: " f"{r} and {s} ({datasets[i]} {datasets[i+1]})" ) return Concat(datasets), kwargs @@ -1570,9 +1457,7 @@ def _open_dataset(*args, zarr_root, **kwargs): "cutout": CutoutGrids, } if mode not in KLASSES: - raise ValueError( - f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}" - ) + raise ValueError(f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}") datasets = [_open(e, zarr_root) for e in grids] datasets, kwargs = _auto_adjust(datasets, kwargs) diff --git a/scripts/get-grids.py b/scripts/get-grids.py deleted file mode 100755 index 704cabd..0000000 --- a/scripts/get-grids.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys - -import climetlab as cml -import numpy as np - -grid = sys.argv[1] - -ds = cml.load_source("file", f"{grid}.grib") - - -field = ds[0] - -latitudes, longitudes = field.grid_points() - -path = f"grid-{grid}.npz" - -np.savez( - "tmp.npz", - latitudes=latitudes, - longitudes=longitudes, -) -os.rename("tmp.npz", path) diff --git a/scripts/get-grids.sh b/scripts/get-grids.sh deleted file mode 100755 index e3df93f..0000000 --- a/scripts/get-grids.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -mir=/usr/local/apps/mars/versions/6.33.17.6/bin/mir - -if [[ ! -f 2t.grib ]]; then -mars<