Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loader-specific args in iris load functions. #3720

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions lib/iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def context(self, **kwargs):
_update(site_configuration)


def _generate_cubes(uris, callback, constraints):
def _generate_cubes(uris, callback, constraints, loader_kwargs=None):
"""Returns a generator of cubes given the URIs and a callback."""
if isinstance(uris, (str, pathlib.PurePath)):
uris = [uris]
Expand All @@ -253,21 +253,25 @@ def _generate_cubes(uris, callback, constraints):
# Call each scheme handler with the appropriate URIs
if scheme == "file":
part_names = [x[1] for x in groups]
for cube in iris.io.load_files(part_names, callback, constraints):
for cube in iris.io.load_files(
part_names, callback, constraints, loader_kwargs
):
yield cube
elif scheme in ["http", "https"]:
urls = [":".join(x) for x in groups]
for cube in iris.io.load_http(urls, callback):
for cube in iris.io.load_http(urls, callback, loader_kwargs):
yield cube
else:
raise ValueError("Iris cannot handle the URI scheme: %s" % scheme)


def _load_collection(uris, constraints=None, callback=None):
def _load_collection(
uris, constraints=None, callback=None, loader_kwargs=None
):
from iris.cube import _CubeFilterCollection

try:
cubes = _generate_cubes(uris, callback, constraints)
cubes = _generate_cubes(uris, callback, constraints, loader_kwargs)
result = _CubeFilterCollection.from_cubes(cubes, constraints)
except EOFError as e:
raise iris.exceptions.TranslationError(
Expand All @@ -276,7 +280,7 @@ def _load_collection(uris, constraints=None, callback=None):
return result


def load(uris, constraints=None, callback=None):
def load(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads any number of Cubes for each constraint.

Expand All @@ -294,17 +298,23 @@ def load(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.

Returns:
An :class:`iris.cube.CubeList`. Note that there is no inherent order
to this :class:`iris.cube.CubeList` and it should be treated as if it
were random.

"""
return _load_collection(uris, constraints, callback).merged().cubes()
return (
_load_collection(uris, constraints, callback, loader_kwargs)
.merged()
.cubes()
)


def load_cube(uris, constraint=None, callback=None):
def load_cube(uris, constraint=None, callback=None, loader_kwargs=None):
"""
Loads a single cube.

Expand All @@ -322,6 +332,8 @@ def load_cube(uris, constraint=None, callback=None):
A constraint.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.

Returns:
An :class:`iris.cube.Cube`.
Expand All @@ -331,7 +343,9 @@ def load_cube(uris, constraint=None, callback=None):
if len(constraints) != 1:
raise ValueError("only a single constraint is allowed")

cubes = _load_collection(uris, constraints, callback).cubes()
cubes = _load_collection(
uris, constraints, callback, loader_kwargs
).cubes()

try:
cube = cubes.merge_cube()
Expand All @@ -343,7 +357,7 @@ def load_cube(uris, constraint=None, callback=None):
return cube


def load_cubes(uris, constraints=None, callback=None):
def load_cubes(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads exactly one Cube for each constraint.

Expand All @@ -361,6 +375,8 @@ def load_cubes(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.

Returns:
An :class:`iris.cube.CubeList`. Note that there is no inherent order
Expand All @@ -369,7 +385,9 @@ def load_cubes(uris, constraints=None, callback=None):

"""
# Merge the incoming cubes
collection = _load_collection(uris, constraints, callback).merged()
collection = _load_collection(
uris, constraints, callback, loader_kwargs
).merged()

# Make sure we have exactly one merged cube per constraint
bad_pairs = [pair for pair in collection.pairs if len(pair) != 1]
Expand All @@ -382,7 +400,7 @@ def load_cubes(uris, constraints=None, callback=None):
return collection.cubes()


def load_raw(uris, constraints=None, callback=None):
def load_raw(uris, constraints=None, callback=None, loader_kwargs=None):
"""
Loads non-merged cubes.

Expand All @@ -406,6 +424,8 @@ def load_raw(uris, constraints=None, callback=None):
One or more constraints.
* callback:
A modifier/filter function.
* loader_kwargs (dict):
Additional settings, specific to the file-format loader.

Returns:
An :class:`iris.cube.CubeList`.
Expand All @@ -414,7 +434,9 @@ def load_raw(uris, constraints=None, callback=None):
from iris.fileformats.um._fast_load import _raw_structured_loading

with _raw_structured_loading():
return _load_collection(uris, constraints, callback).cubes()
return _load_collection(
uris, constraints, callback, loader_kwargs
).cubes()


save = iris.io.save
Expand Down
20 changes: 15 additions & 5 deletions lib/iris/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def expand_filespecs(file_specs):
return [fname for fnames in all_expanded for fname in fnames]


def load_files(filenames, callback, constraints=None):
def load_files(filenames, callback, constraints=None, loader_kwargs=None):
"""
Takes a list of filenames which may also be globs, and optionally a
constraint set and a callback function, and returns a
Expand All @@ -202,19 +202,24 @@ def load_files(filenames, callback, constraints=None):
handler_map[handling_format_spec].append(fn)

# Call each iris format handler with the appropriate filenames
if loader_kwargs is None:
loader_kwargs = {}

for handling_format_spec in sorted(handler_map):
fnames = handler_map[handling_format_spec]
if handling_format_spec.constraint_aware_handler:
for cube in handling_format_spec.handler(
fnames, callback, constraints
fnames, callback, constraints, **loader_kwargs
):
yield cube
else:
for cube in handling_format_spec.handler(fnames, callback):
for cube in handling_format_spec.handler(
fnames, callback, **loader_kwargs
):
yield cube


def load_http(urls, callback):
def load_http(urls, callback, loader_kwargs=None):
"""
Takes a list of urls and a callback function, and returns a generator
of Cubes from the given URLs.
Expand All @@ -234,9 +239,14 @@ def load_http(urls, callback):
handler_map[handling_format_spec].append(url)

# Call each iris format handler with the appropriate filenames
if loader_kwargs is None:
loader_kwargs = {}

for handling_format_spec in sorted(handler_map):
fnames = handler_map[handling_format_spec]
for cube in handling_format_spec.handler(fnames, callback):
for cube in handling_format_spec.handler(
fnames, callback, **loader_kwargs
):
yield cube


Expand Down
145 changes: 145 additions & 0 deletions lib/iris/tests/unit/fileformats/test_loader_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright Iris contributors
#
# This file is part of Iris and is released under the LGPL license.
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.
"""
Test iris.load functions support for loader-specific keywords.

"""

import os.path
import shutil
import tempfile

import iris
from iris.cube import Cube, CubeList
from iris.fileformats import FORMAT_AGENT
from iris.io.format_picker import FileExtension, FormatSpecification

# import iris tests first so that some things can be initialised before
# importing anything else
import iris.tests as tests


# Define an extra 'dummy' FormatHandler for load-kwarg testing.
# We add this, temporarily, to the Iris file-format picker for testing.
def format_x_handler_function(fnames, callback, **loader_kwargs):
    """
    A dummy format handler function, for load-kwarg testing.

    Yields a single cube whose attributes record the arguments the handler
    was called with, under the keys "fnames", "callback" and "loader_kwargs",
    so that a test can inspect what was passed in.

    """
    yield Cube(
        1,
        attributes={
            "fnames": fnames,
            "callback": callback,
            "loader_kwargs": loader_kwargs,
        },
    )


# A Format spec for the fake file format.
# The extension string which identifies a file as belonging to the fake
# format (presumably chosen so that no real data file will ever match it).
FORMAT_X_EXTENSION = "._xtst_"
# The spec links the fake format to the dummy handler function above.
# NOTE(review): the positional arguments look like (format name, file element
# to test, required element value, handler) -- confirm against the
# FormatSpecification signature.
FORMAT_X = FormatSpecification(
    "Testing file handler",
    FileExtension(),
    FORMAT_X_EXTENSION,
    format_x_handler_function,
)


class LoadFunctionMixin:
    """
    Common code to test load/load_cube/load_cubes/load_raw.

    Each test calls the iris load function named by
    :attr:`load_function_name`, passing extra keywords via ``loader_kwargs``,
    and checks what arrives at the format handler.

    Inheritors must also derive from a TestCase class (for the ``assert*``
    methods) and must override :attr:`load_function_name`.

    """

    # Inheritor must set this to the name of an iris load function.
    # Note: storing the function itself causes it to be mis-identified as an
    # instance method when called, so doing it by name is clearer.
    load_function_name = "xxx"

    @classmethod
    def setUpClass(cls):
        """Create a dummy test file, then register the dummy file format."""
        # Create a temporary working directory.
        cls._temp_dir = tempfile.mkdtemp()
        try:
            # Store the path of a dummy file whose name matches the picker.
            filename = "testfile" + FORMAT_X_EXTENSION
            cls.test_filepath = os.path.join(cls._temp_dir, filename)
            # Create the dummy file.
            with open(cls.test_filepath, "w") as f_open:
                # Write some data to satisfy the other, signature-based
                # pickers.
                # TODO: this really shouldn't be necessary ??
                f_open.write("\x00" * 100)
            # Add our dummy format handler to the common Iris io file-picker.
            # This is done last, so that an earlier failure cannot leave the
            # handler registered in the shared FORMAT_AGENT.
            FORMAT_AGENT.add_spec(FORMAT_X)
        except Exception:
            # unittest does not call tearDownClass when setUpClass fails, so
            # the temporary directory must be removed here.
            shutil.rmtree(cls._temp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Unregister the dummy file format and delete the temporary files."""
        try:
            # Remove the dummy format handler.
            # N.B. no public api, so uses a private property of FormatAgent.
            FORMAT_AGENT._format_specs.remove(FORMAT_X)
        finally:
            # Delete the temporary directory, even if the removal failed.
            shutil.rmtree(cls._temp_dir)

    def _load_a_cube(self, *args, **kwargs):
        """
        Call the load function under test, and return the single cube loaded.

        Positional args are passed straight to the load function.
        Keyword args are passed to it *as a dict*, in the ``loader_kwargs``
        keyword (so this is an empty dict when no keywords are given).

        """
        load_function = getattr(iris, self.load_function_name)
        result = load_function(*args, loader_kwargs=kwargs)
        if load_function is not iris.load_cube:
            # Handle 'other' load functions, which return CubeLists ...
            self.assertIsInstance(result, CubeList)
            # ... however, we intend that all uses will return only 1 cube.
            self.assertEqual(len(result), 1)
            result = result[0]
        self.assertIsInstance(result, Cube)
        return result

    def test_extra_args(self):
        # Extra keywords should arrive at the handler unchanged.
        test_kwargs = {"loader_a": 1, "loader_b": "two"}
        result = self._load_a_cube(self.test_filepath, **test_kwargs)
        expected = dict(
            fnames=[self.test_filepath],
            callback=None,
            loader_kwargs=test_kwargs,
        )
        self.assertEqual(result.attributes, expected)

    def test_no_extra_args(self):
        # With no extra keywords, the handler should receive none.
        result = self._load_a_cube(self.test_filepath)
        expected = dict(
            fnames=[self.test_filepath], callback=None, loader_kwargs={}
        )
        self.assertEqual(result.attributes, expected)

    @tests.skip_data
    def test_wrong_loader_noargs_ok(self):
        # A real data file, loaded with no extra keywords, should still load.
        filepath = tests.get_data_path(
            ["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"]
        )
        result = self._load_a_cube(filepath, "co2")
        self.assertIsNotNone(result)

    @tests.skip_data
    def test_wrong_loader_withargs__fail(self):
        # A keyword which the selected loader does not accept should raise.
        filepath = tests.get_data_path(
            ["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"]
        )
        test_kwargs = {"junk": "this"}
        msg = "load.* got an unexpected keyword argument 'junk'"
        with self.assertRaisesRegex(TypeError, msg):
            self._load_a_cube(filepath, "co2", **test_kwargs)


class TestLoad(LoadFunctionMixin, tests.IrisTest):
    """Run the common loader-kwargs tests against ``iris.load``."""

    load_function_name = "load"


class TestLoadCubes(LoadFunctionMixin, tests.IrisTest):
    """Run the common loader-kwargs tests against ``iris.load_cubes``."""

    load_function_name = "load_cubes"


class TestLoadCube(LoadFunctionMixin, tests.IrisTest):
    """Run the common loader-kwargs tests against ``iris.load_cube``."""

    load_function_name = "load_cube"


class TestLoadRaw(LoadFunctionMixin, tests.IrisTest):
    """Run the common loader-kwargs tests against ``iris.load_raw``."""

    load_function_name = "load_raw"


# Invoke the iris.tests entry point when this file is executed as a script.
if __name__ == "__main__":
    tests.main()