Skip to content

Commit

Permalink
merge commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Rilwan-Adewoyin committed Jan 10, 2025
2 parents 6ea4aa8 + 91e8b96 commit 47dae0c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
23 changes: 10 additions & 13 deletions src/anemoi/datasets/data/forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@

from .dataset import Dataset
from .debug import debug_indexing
from .indexing import apply_index_to_slices_changes
from .indexing import _extend_shape
from .indexing import expand_list_indexing
from .indexing import index_to_slices
from .indexing import length_to_slices
from .indexing import update_tuple
from .indexing import get_indices_for_child_datasets_from_combined_axis_index
from .indexing import _extend_shape

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -261,36 +257,37 @@ def shape(self):
# result = [d[update_tuple(index, self.axis, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
# result = np.concatenate(result, axis=self.axis)
# return apply_index_to_slices_changes(result, changes)

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):

index = _extend_shape(index, self.shape)

lengths = [d.shape[self.axis] for d in self.datasets]

# New logic required here
# If we are indexing along the axis which is also the axis that the two datasets are combined along,
# If we are indexing along the axis which is also the axis that the two datasets are combined along,
# we need to ensure we collect the correct indices from each of the child datasets:
# 1) ascertain the lengths of each child dataset in the combined axis
# 2) create new indexing indices for each child dataset by correctly adjusting the combined axis index for each child
# 3) index each child dataset using the new indices
# 4) concatenate the results in an order that matches the original combined axis index

index_children:list[tuple[slice,list[int]]] = get_indices_for_child_datasets_from_combined_axis_index(self.axis, index, lengths)
index_children: list[tuple[slice, list[int]]] = get_indices_for_child_datasets_from_combined_axis_index(
self.axis, index, lengths
)

child_datasets = [d.__getitem__(index_child) for d, index_child in zip(self.datasets, index_children)]

# Interleaving logic not needed since the datasets and its selections are already in the correct order
# Here we interleave the results in the order of the original combined axis index
# This is done by concatenating the results in the order of the original combined axis index
# result = interleave_child_datasets_on_combined_axis(self.axis, child_datasets, index, lengths)

result = np.concatenate(child_datasets, axis=self.axis)
return result



@debug_indexing
def _get_slice(self, s):
return np.stack([self[i] for i in range(*s.indices(self._len))])
Expand Down
29 changes: 19 additions & 10 deletions src/anemoi/datasets/data/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
# nor does it submit to any jurisdiction.


import copy
from functools import wraps

import numpy as np
import copy


def _tuple_with_slices(t, shape):
"""Replace all integers in a tuple with slices, so we preserve the dimensionality."""
Expand Down Expand Up @@ -153,6 +154,7 @@ def _(i):

# return wrapper


def expand_list_indexing(method):
"""Allows to use slices, lists, and tuples to select data from the dataset. Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves."""

Expand All @@ -163,11 +165,12 @@ def wrapper(self, index):

if not any(isinstance(i, (list, tuple)) for i in index):
return method(self, index)

return method(self, index)

return wrapper


def make_slice_or_index_from_list_or_tuple(indices):
"""Convert a list or tuple of indices to a slice or an index, if possible."""

Expand All @@ -181,7 +184,10 @@ def make_slice_or_index_from_list_or_tuple(indices):

return indices

def get_indices_for_child_datasets_from_combined_axis_index(join_axis:int, index_combined:tuple[slice|list[int],], child_datasets_axis_lengths:list[int]) -> list[tuple[slice,list[int]]]:

def get_indices_for_child_datasets_from_combined_axis_index(
join_axis: int, index_combined: tuple[slice | list[int],], child_datasets_axis_lengths: list[int]
) -> list[tuple[slice, list[int]]]:
"""Given a combined axis index, and the axis along which the child datasets are joined, return the indices for each child dataset."""

# 1) ascertain the lengths of each child dataset in the combined axis
Expand All @@ -192,18 +198,18 @@ def get_indices_for_child_datasets_from_combined_axis_index(join_axis:int, index
cumulative_lengths = np.cumsum(child_datasets_axis_lengths)

start_indices_child = [0] + cumulative_lengths[:-1].tolist()
end_indices_child = [s+l for s,l in zip(start_indices_child, child_datasets_axis_lengths)]
end_indices_child = [s + l for s, l in zip(start_indices_child, child_datasets_axis_lengths)]

index_children = []
for idx_child in range(len(child_datasets_axis_lengths)):
index_child = list( copy.deepcopy(index_combined) )
index_child = list(copy.deepcopy(index_combined))

# in the join axis, select the indices that map to the child dataset at position i
index_at_join_axis = index_child[join_axis]

if isinstance(index_at_join_axis, slice):
start, stop, step = index_at_join_axis.indices(cumulative_lengths[idx_child])

# Ensure the slice is within the bounds
start = max(start, start_indices_child[idx_child])
stop = min(stop, end_indices_child[idx_child])
Expand All @@ -213,14 +219,17 @@ def get_indices_for_child_datasets_from_combined_axis_index(join_axis:int, index
new_index_at_join_axis = slice(adjusted_start, adjusted_stop, step)
index_child[join_axis] = new_index_at_join_axis


elif isinstance(index_at_join_axis, list):
new_index_at_join_aixs = [ idx for idx in index_at_join_axis if idx >= start_indices_child[idx_child] and idx < end_indices_child[idx_child] ]
adjusted_index_at_join_axis = [ idx - start_indices_child[idx_child] for idx in new_index_at_join_aixs ]
new_index_at_join_aixs = [
idx
for idx in index_at_join_axis
if idx >= start_indices_child[idx_child] and idx < end_indices_child[idx_child]
]
adjusted_index_at_join_axis = [idx - start_indices_child[idx_child] for idx in new_index_at_join_aixs]
index_child[join_axis] = adjusted_index_at_join_axis
else:
ValueError("Index at join axis is not a slice or a list")

index_child = tuple(index_child)
index_children.append(index_child)
return index_children
7 changes: 2 additions & 5 deletions src/anemoi/datasets/data/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from .debug import Source
from .debug import debug_indexing
from .forwards import Forwards
from .indexing import apply_index_to_slices_changes
from .indexing import expand_list_indexing
from .indexing import index_to_slices
from .indexing import update_tuple

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,14 +56,14 @@ def mutate(self):
# result = result[:, previous]
# result = apply_index_to_slices_changes(result, changes)
# return result

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
result = self.dataset[index]
result = result[:, self.indices]
return result

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
Expand Down
3 changes: 0 additions & 3 deletions src/anemoi/datasets/data/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
from .debug import Source
from .debug import debug_indexing
from .forwards import Forwards
from .indexing import apply_index_to_slices_changes
from .indexing import expand_list_indexing
from .indexing import index_to_slices
from .indexing import make_slice_or_index_from_list_or_tuple
from .indexing import update_tuple

LOG = logging.getLogger(__name__)

Expand Down

0 comments on commit 47dae0c

Please sign in to comment.