Skip to content

Commit

Permalink
Fix split_on option for retrievals (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz authored Feb 20, 2024
1 parent 2a33ace commit 942829b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 44 deletions.
105 changes: 63 additions & 42 deletions earthkit/data/sources/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,57 @@
LOG = logging.getLogger(__name__)


class FileParts:
def __init__(self, path, parts):
self.path, self.parts = self._paths_and_parts(path, parts)

def is_empty(self):
return not (self.parts is not None and any(x is not None for x in self.parts))

def update(self, path):
if self.path != path:
self.path, self.parts = self._paths_and_parts(path, self.parts)
return self.path

def _paths_and_parts(self, paths, parts):
"""Preprocess paths and parts.
Parameters
----------
paths: str or list/tuple
The path(s). When it is a sequence either each
item is a path (str), or a pair of a path and :ref:`parts <parts>`.
parts: part,list/tuple of parts or None.
The :ref:`parts <parts>`.
Returns
-------
str or list of str
The path or paths.
SimplePart, list or tuple, None
The parts (one for each path). A part can be a single
SimplePart, a list/tuple of SimpleParts or None.
"""
if parts is None:
if isinstance(paths, str):
return paths, None
elif isinstance(paths, (list, tuple)) and all(
isinstance(p, str) for p in paths
):
return paths, [None] * len(paths)

paths = check_urls_and_parts(paths, parts)
paths_and_parts = ensure_urls_and_parts(paths, parts, compress=True)

paths, parts = zip(*paths_and_parts)
assert len(paths) == len(parts)
if len(paths) == 1:
return paths[0], parts[0]
else:
return paths, parts


class FileSourceMeta(type(Source), type(os.PathLike)):
def patch(cls, obj, *args, **kwargs):
if "reader" in kwargs:
Expand All @@ -39,13 +90,18 @@ def __init__(self, path=None, filter=None, merger=None, parts=None, **kwargs):
Source.__init__(self, **kwargs)
self.filter = filter
self.merger = merger
self.path, self.parts = self._paths_and_parts(path, parts)
self._parts = FileParts(path, parts)
self.path = self._parts.path

if self._kwargs.get("indexing", False):
if self.parts is not None and any(x is not None for x in self.parts):
if not self._parts.is_empty():
raise ValueError("Cannot specify parts when indexing is enabled!")

def mutate(self):
# the initial path is reset for e.g. the retrievals. We have to ensure
# the parts are still correctly formed
self.check_parts()

if isinstance(self.path, (list, tuple)):
if len(self.path) == 1:
self.path = self.path[0]
Expand All @@ -54,7 +110,7 @@ def mutate(self):
"multi",
[
from_source("file", p, parts=part, **self._kwargs)
for p, part in zip(self.path, self.parts)
for p, part in zip(self.path, self._parts.parts)
],
filter=self.filter,
merger=self.merger,
Expand Down Expand Up @@ -89,8 +145,9 @@ def merge(cls, sources):
@property
def _reader(self):
if self._reader_ is None:
self.check_parts()
self._reader_ = reader(
self, self.path, content_type=self.content_type, parts=self.parts
self, self.path, content_type=self.content_type, parts=self._parts.parts
)
return self._reader_

Expand Down Expand Up @@ -189,44 +246,8 @@ def bounding_box(self):
def statistics(self, **kwargs):
return self._reader.statistics(**kwargs)

@staticmethod
def _paths_and_parts(paths, parts):
"""Preprocess paths and parts.
Parameters
----------
paths: str or list/tuple
The path(s). When it is a sequence either each
item is a path (str), or a pair of a path and :ref:`parts <parts>`.
parts: part,list/tuple of parts or None.
The :ref:`parts <parts>`.
Returns
-------
str or list of str
The path or paths.
SimplePart, list or tuple, None
The parts (one for each path). A part can be a single
SimplePart, a list/tuple of SimpleParts or None.
"""
if parts is None:
if isinstance(paths, str):
return paths, None
elif isinstance(paths, (list, tuple)) and all(
isinstance(p, str) for p in paths
):
return paths, [None] * len(paths)

paths = check_urls_and_parts(paths, parts)
paths_and_parts = ensure_urls_and_parts(paths, parts, compress=True)

paths, parts = zip(*paths_and_parts)
assert len(paths) == len(parts)
if len(paths) == 1:
return paths[0], parts[0]
else:
return paths, parts
def check_parts(self):
self.path = self._parts.update(self.path)


class IndexedFileSource(FileSource):
Expand Down
21 changes: 19 additions & 2 deletions tests/utils/test_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pytest

from earthkit.data.sources.file import FileSource
from earthkit.data.sources.file import FileParts
from earthkit.data.sources.url import Url
from earthkit.data.utils.parts import SimplePart

Expand Down Expand Up @@ -81,11 +81,28 @@
],
)
def test_prepare_file_parts(paths, parts, expected_paths, expected_parts):
res_paths, res_parts = FileSource._paths_and_parts(paths, parts)
p = FileParts(paths, parts)
res_paths, res_parts = p.path, p.parts
assert res_paths == expected_paths
assert res_parts == expected_parts


def test_update_file_parts():
p = FileParts("", None)
assert p.path == ""
assert p.parts is None

res = p.update("a.grib")
assert res == "a.grib"
assert p.path == res
assert p.parts is None

res = p.update(["a.grib", "b.grib"])
assert res == ["a.grib", "b.grib"]
assert p.path == res
assert p.parts == [None, None]


@pytest.mark.parametrize(
"urls,parts,expected_values",
[
Expand Down

0 comments on commit 942829b

Please sign in to comment.