Skip to content

Commit

Permalink
Support optional loader keys in iris.save functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
pp-mo committed Feb 7, 2022
1 parent 1d5eb3e commit e031548
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 18 deletions.
48 changes: 35 additions & 13 deletions lib/iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def context(self, **kwargs):
_update(site_configuration)


def _generate_cubes(uris, callback, constraints):
def _generate_cubes(uris, callback, constraints, loader_kwargs=None):
"""Returns a generator of cubes given the URIs and a callback."""
if isinstance(uris, (str, pathlib.PurePath)):
uris = [uris]
Expand All @@ -253,21 +253,25 @@ def _generate_cubes(uris, callback, constraints):
# Call each scheme handler with the appropriate URIs
if scheme == "file":
part_names = [x[1] for x in groups]
for cube in iris.io.load_files(part_names, callback, constraints):
for cube in iris.io.load_files(
part_names, callback, constraints, loader_kwargs
):
yield cube
elif scheme in ["http", "https"]:
urls = [":".join(x) for x in groups]
for cube in iris.io.load_http(urls, callback):
for cube in iris.io.load_http(urls, callback, loader_kwargs):
yield cube
else:
raise ValueError("Iris cannot handle the URI scheme: %s" % scheme)


def _load_collection(uris, constraints=None, callback=None):
def _load_collection(
uris, constraints=None, callback=None, loader_kwargs=None
):
from iris.cube import _CubeFilterCollection

try:
cubes = _generate_cubes(uris, callback, constraints)
cubes = _generate_cubes(uris, callback, constraints, loader_kwargs)
result = _CubeFilterCollection.from_cubes(cubes, constraints)
except EOFError as e:
raise iris.exceptions.TranslationError(
Expand All @@ -276,7 +280,7 @@ def _load_collection(uris, constraints=None, callback=None):
return result


def load(uris, constraints=None, callback=None):
def load(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads any number of Cubes for each constraint.
Expand All @@ -294,17 +298,23 @@ def load(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.
Returns:
An :class:`iris.cube.CubeList`. Note that there is no inherent order
to this :class:`iris.cube.CubeList` and it should be treated as if it
were random.
"""
return _load_collection(uris, constraints, callback).merged().cubes()
return (
_load_collection(uris, constraints, callback, loader_kwargs)
.merged()
.cubes()
)


def load_cube(uris, constraint=None, callback=None):
def load_cube(uris, constraint=None, callback=None, loader_kwargs=None):
"""
Loads a single cube.
Expand All @@ -322,6 +332,8 @@ def load_cube(uris, constraint=None, callback=None):
A constraint.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.
Returns:
An :class:`iris.cube.Cube`.
Expand All @@ -331,7 +343,9 @@ def load_cube(uris, constraint=None, callback=None):
if len(constraints) != 1:
raise ValueError("only a single constraint is allowed")

cubes = _load_collection(uris, constraints, callback).cubes()
cubes = _load_collection(
uris, constraints, callback, loader_kwargs
).cubes()

try:
cube = cubes.merge_cube()
Expand All @@ -343,7 +357,7 @@ def load_cube(uris, constraint=None, callback=None):
return cube


def load_cubes(uris, constraints=None, callback=None):
def load_cubes(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads exactly one Cube for each constraint.
Expand All @@ -361,6 +375,8 @@ def load_cubes(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.
Returns:
An :class:`iris.cube.CubeList`. Note that there is no inherent order
Expand All @@ -369,7 +385,9 @@ def load_cubes(uris, constraints=None, callback=None):
"""
# Merge the incoming cubes
collection = _load_collection(uris, constraints, callback).merged()
collection = _load_collection(
uris, constraints, callback, loader_kwargs
).merged()

# Make sure we have exactly one merged cube per constraint
bad_pairs = [pair for pair in collection.pairs if len(pair) != 1]
Expand All @@ -382,7 +400,7 @@ def load_cubes(uris, constraints=None, callback=None):
return collection.cubes()


def load_raw(uris, constraints=None, callback=None):
def load_raw(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads non-merged cubes.
Expand All @@ -406,6 +424,8 @@ def load_raw(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.
Returns:
An :class:`iris.cube.CubeList`.
Expand All @@ -414,7 +434,9 @@ def load_raw(uris, constraints=None, callback=None):
from iris.fileformats.um._fast_load import _raw_structured_loading

with _raw_structured_loading():
return _load_collection(uris, constraints, callback).cubes()
return _load_collection(
uris, constraints, callback, loader_kwargs
).cubes()


save = iris.io.save
Expand Down
20 changes: 15 additions & 5 deletions lib/iris/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def expand_filespecs(file_specs):
return [fname for fnames in all_expanded for fname in fnames]


def load_files(filenames, callback, constraints=None):
def load_files(filenames, callback, constraints=None, loader_kwargs=None):
"""
Takes a list of filenames which may also be globs, and optionally a
constraint set and a callback function, and returns a
Expand All @@ -202,19 +202,24 @@ def load_files(filenames, callback, constraints=None):
handler_map[handling_format_spec].append(fn)

# Call each iris format handler with the approriate filenames
if loader_kwargs is None:
loader_kwargs = {}

for handling_format_spec in sorted(handler_map):
fnames = handler_map[handling_format_spec]
if handling_format_spec.constraint_aware_handler:
for cube in handling_format_spec.handler(
fnames, callback, constraints
fnames, callback, constraints, **loader_kwargs
):
yield cube
else:
for cube in handling_format_spec.handler(fnames, callback):
for cube in handling_format_spec.handler(
fnames, callback, **loader_kwargs
):
yield cube


def load_http(urls, callback):
def load_http(urls, callback, loader_kwargs=None):
"""
Takes a list of urls and a callback function, and returns a generator
of Cubes from the given URLs.
Expand All @@ -234,9 +239,14 @@ def load_http(urls, callback):
handler_map[handling_format_spec].append(url)

# Call each iris format handler with the appropriate filenames
if loader_kwargs is None:
loader_kwargs = {}

for handling_format_spec in sorted(handler_map):
fnames = handler_map[handling_format_spec]
for cube in handling_format_spec.handler(fnames, callback):
for cube in handling_format_spec.handler(
fnames, callback, **loader_kwargs
):
yield cube


Expand Down
145 changes: 145 additions & 0 deletions lib/iris/tests/unit/fileformats/test_loader_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright Iris contributors
#
# This file is part of Iris and is released under the LGPL license.
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.
"""
Test iris.load functions support for loader-specific keywords.
"""

import os.path
import shutil
import tempfile

import iris
from iris.cube import Cube, CubeList
from iris.fileformats import FORMAT_AGENT
from iris.io.format_picker import FileExtension, FormatSpecification

# import iris tests first so that some things can be initialised before
# importing anything else
import iris.tests as tests


# Define an extra 'dummy' FormatHandler for load-kwarg testing.
# We add this, temporarily, to the Iris file-format picker for testing.
def format_x_handler_function(fnames, callback, **loader_kwargs):
# A format handler function.
# Yield a single cube, with its attributes set to the call arguments.
call_args = dict(
fnames=fnames, callback=callback, loader_kwargs=loader_kwargs
)
cube = Cube(1, attributes=call_args)
yield cube


# A Format spec for the fake file format.
FORMAT_X_EXTENSION = "._xtst_"
FORMAT_X = FormatSpecification(
"Testing file handler",
FileExtension(),
FORMAT_X_EXTENSION,
format_x_handler_function,
)


class LoadFunctionMixin:
# Common code to test load/load_cube/load_cubes/load_raw.

# Inheritor must set this to the name of an iris load function.
# Note: storing the function itself causes it to be mis-identified as an
# instance method when called, so doing it by name is clearer.
load_function_name = "xxx"

@classmethod
def setUpClass(cls):
# Add our dummy format handler to the common Iris io file-picker.
FORMAT_AGENT.add_spec(FORMAT_X)
# Create a temporary working directory.
cls._temp_dir = tempfile.mkdtemp()
# Store the path of a dummy file whose name matches the picker.
filename = "testfile" + FORMAT_X_EXTENSION
cls.test_filepath = os.path.join(cls._temp_dir, filename)
# Create the dummy file.
with open(cls.test_filepath, "w") as f_open:
# Write some data to satisfy the other, signature-based pickers.
# TODO: this really shouldn't be necessary ??
f_open.write("\x00" * 100)

@classmethod
def tearDownClass(cls):
# Remove the dummy format handler.
# N.B. no public api, so uses a private property of FormatAgent.
FORMAT_AGENT._format_specs.remove(FORMAT_X)
# Delete the temporary directory.
shutil.rmtree(cls._temp_dir)

def _load_a_cube(self, *args, **kwargs):
load_function = getattr(iris, self.load_function_name)
result = load_function(*args, loader_kwargs=kwargs)
if load_function is not iris.load_cube:
# Handle 'other' load functions, which return CubeLists ...
self.assertIsInstance(result, CubeList)
# ... however, we intend that all uses will return only 1 cube.
self.assertEqual(len(result), 1)
result = result[0]
self.assertIsInstance(result, Cube)
return result

def test_extra_args(self):
test_kwargs = {"loader_a": 1, "loader_b": "two"}
result = self._load_a_cube(self.test_filepath, **test_kwargs)
self.assertEqual(
result.attributes,
dict(
fnames=[self.test_filepath],
callback=None,
loader_kwargs=test_kwargs,
),
)

def test_no_extra_args(self):
result = self._load_a_cube(self.test_filepath)
self.assertEqual(
result.attributes,
dict(fnames=[self.test_filepath], callback=None, loader_kwargs={}),
)

@tests.skip_data
def test_wrong_loader_noargs_ok(self):
filepath = tests.get_data_path(
["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"]
)
result = self._load_a_cube(filepath, "co2")
self.assertIsNot(result, None)

@tests.skip_data
def test_wrong_loader_withargs__fail(self):
filepath = tests.get_data_path(
["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"]
)
test_kwargs = {"junk": "this"}
msg = "load.* got an unexpected keyword argument 'junk'"
with self.assertRaisesRegex(TypeError, msg):
_ = self._load_a_cube(filepath, "co2", **test_kwargs)


class TestLoad(LoadFunctionMixin, tests.IrisTest):
load_function_name = "load"


class TestLoadCubes(LoadFunctionMixin, tests.IrisTest):
load_function_name = "load_cubes"


class TestLoadCube(LoadFunctionMixin, tests.IrisTest):
load_function_name = "load_cube"


class TestLoadRaw(LoadFunctionMixin, tests.IrisTest):
load_function_name = "load_raw"


if __name__ == "__main__":
tests.main()

0 comments on commit e031548

Please sign in to comment.