diff --git a/lib/iris/__init__.py b/lib/iris/__init__.py index 713c163deb..4cf977164a 100644 --- a/lib/iris/__init__.py +++ b/lib/iris/__init__.py @@ -240,7 +240,7 @@ def context(self, **kwargs): _update(site_configuration) -def _generate_cubes(uris, callback, constraints): +def _generate_cubes(uris, callback, constraints, loader_kwargs=None): """Returns a generator of cubes given the URIs and a callback.""" if isinstance(uris, (str, pathlib.PurePath)): uris = [uris] @@ -253,21 +253,25 @@ def _generate_cubes(uris, callback, constraints): # Call each scheme handler with the appropriate URIs if scheme == "file": part_names = [x[1] for x in groups] - for cube in iris.io.load_files(part_names, callback, constraints): + for cube in iris.io.load_files( + part_names, callback, constraints, loader_kwargs + ): yield cube elif scheme in ["http", "https"]: urls = [":".join(x) for x in groups] - for cube in iris.io.load_http(urls, callback): + for cube in iris.io.load_http(urls, callback, loader_kwargs): yield cube else: raise ValueError("Iris cannot handle the URI scheme: %s" % scheme) -def _load_collection(uris, constraints=None, callback=None): +def _load_collection( + uris, constraints=None, callback=None, loader_kwargs=None +): from iris.cube import _CubeFilterCollection try: - cubes = _generate_cubes(uris, callback, constraints) + cubes = _generate_cubes(uris, callback, constraints, loader_kwargs) result = _CubeFilterCollection.from_cubes(cubes, constraints) except EOFError as e: raise iris.exceptions.TranslationError( @@ -276,7 +280,7 @@ def _load_collection(uris, constraints=None, callback=None): return result -def load(uris, constraints=None, callback=None): +def load(uris, constraints=None, callback=None, loader_kwargs=None): """ Loads any number of Cubes for each constraint. @@ -294,6 +298,8 @@ def load(uris, constraints=None, callback=None): One or more constraints. * callback: A modifier/filter function. + * loader_kwargs (dict): + Additional settings, specific to the file-format loader. Returns: An :class:`iris.cube.CubeList`. Note that there is no inherent order @@ -301,10 +307,14 @@ def load(uris, constraints=None, callback=None): were random. """ - return _load_collection(uris, constraints, callback).merged().cubes() + return ( + _load_collection(uris, constraints, callback, loader_kwargs) + .merged() + .cubes() + ) -def load_cube(uris, constraint=None, callback=None): +def load_cube(uris, constraint=None, callback=None, loader_kwargs=None): """ Loads a single cube. @@ -322,6 +332,8 @@ def load_cube(uris, constraint=None, callback=None): A constraint. * callback: A modifier/filter function. + * loader_kwargs (dict): + Additional settings, specific to the file-format loader. Returns: An :class:`iris.cube.Cube`. @@ -331,7 +343,9 @@ def load_cube(uris, constraint=None, callback=None): if len(constraints) != 1: raise ValueError("only a single constraint is allowed") - cubes = _load_collection(uris, constraints, callback).cubes() + cubes = _load_collection( + uris, constraints, callback, loader_kwargs + ).cubes() try: cube = cubes.merge_cube() @@ -343,7 +357,7 @@ def load_cube(uris, constraint=None, callback=None): return cube -def load_cubes(uris, constraints=None, callback=None): +def load_cubes(uris, constraints=None, callback=None, loader_kwargs=None): """ Loads exactly one Cube for each constraint. @@ -361,6 +375,8 @@ def load_cubes(uris, constraints=None, callback=None): One or more constraints. * callback: A modifier/filter function. + * loader_kwargs (dict): + Additional settings, specific to the file-format loader. Returns: An :class:`iris.cube.CubeList`. Note that there is no inherent order @@ -369,7 +385,9 @@ def load_cubes(uris, constraints=None, callback=None): """ # Merge the incoming cubes - collection = _load_collection(uris, constraints, callback).merged() + collection = _load_collection( + uris, constraints, callback, loader_kwargs + ).merged() # Make sure we have exactly one merged cube per constraint bad_pairs = [pair for pair in collection.pairs if len(pair) != 1] @@ -382,7 +400,7 @@ def load_cubes(uris, constraints=None, callback=None): return collection.cubes() -def load_raw(uris, constraints=None, callback=None): +def load_raw(uris, constraints=None, callback=None, loader_kwargs=None): """ Loads non-merged cubes. @@ -406,6 +424,8 @@ def load_raw(uris, constraints=None, callback=None): One or more constraints. * callback: A modifier/filter function. + * loader_kwargs (dict): + Additional settings, specific to the file-format loader. Returns: An :class:`iris.cube.CubeList`. @@ -414,7 +434,9 @@ def load_raw(uris, constraints=None, callback=None): from iris.fileformats.um._fast_load import _raw_structured_loading with _raw_structured_loading(): - return _load_collection(uris, constraints, callback).cubes() + return _load_collection( + uris, constraints, callback, loader_kwargs + ).cubes() save = iris.io.save diff --git a/lib/iris/io/__init__.py b/lib/iris/io/__init__.py index 034fa4baab..9e4e9fcc76 100644 --- a/lib/iris/io/__init__.py +++ b/lib/iris/io/__init__.py @@ -176,7 +176,7 @@ def expand_filespecs(file_specs): return [fname for fnames in all_expanded for fname in fnames] -def load_files(filenames, callback, constraints=None): +def load_files(filenames, callback, constraints=None, loader_kwargs=None): """ Takes a list of filenames which may also be globs, and optionally a constraint set and a callback function, and returns a @@ -202,19 +202,24 @@ def load_files(filenames, callback, constraints=None): handler_map[handling_format_spec].append(fn) # Call each iris format handler with the approriate filenames + if loader_kwargs is None: + loader_kwargs = {} + for handling_format_spec in sorted(handler_map): fnames = handler_map[handling_format_spec] if handling_format_spec.constraint_aware_handler: for cube in handling_format_spec.handler( - fnames, callback, constraints + fnames, callback, constraints, **loader_kwargs ): yield cube else: - for cube in handling_format_spec.handler(fnames, callback): + for cube in handling_format_spec.handler( + fnames, callback, **loader_kwargs + ): yield cube -def load_http(urls, callback): +def load_http(urls, callback, loader_kwargs=None): """ Takes a list of urls and a callback function, and returns a generator of Cubes from the given URLs. @@ -234,9 +239,14 @@ def load_http(urls, callback): handler_map[handling_format_spec].append(url) # Call each iris format handler with the appropriate filenames + if loader_kwargs is None: + loader_kwargs = {} + for handling_format_spec in sorted(handler_map): fnames = handler_map[handling_format_spec] - for cube in handling_format_spec.handler(fnames, callback): + for cube in handling_format_spec.handler( + fnames, callback, **loader_kwargs + ): yield cube diff --git a/lib/iris/tests/unit/fileformats/test_loader_kwargs.py b/lib/iris/tests/unit/fileformats/test_loader_kwargs.py new file mode 100644 index 0000000000..b50d96ebf5 --- /dev/null +++ b/lib/iris/tests/unit/fileformats/test_loader_kwargs.py @@ -0,0 +1,145 @@ +# Copyright Iris contributors +# +# This file is part of Iris and is released under the LGPL license. +# See COPYING and COPYING.LESSER in the root of the repository for full +# licensing details. +""" +Test iris.load functions support for loader-specific keywords. + +""" + +import os.path +import shutil +import tempfile + +import iris +from iris.cube import Cube, CubeList +from iris.fileformats import FORMAT_AGENT +from iris.io.format_picker import FileExtension, FormatSpecification + +# import iris tests first so that some things can be initialised before +# importing anything else +import iris.tests as tests + + +# Define an extra 'dummy' FormatHandler for load-kwarg testing. +# We add this, temporarily, to the Iris file-format picker for testing. +def format_x_handler_function(fnames, callback, **loader_kwargs): + # A format handler function. + # Yield a single cube, with its attributes set to the call arguments. + call_args = dict( + fnames=fnames, callback=callback, loader_kwargs=loader_kwargs + ) + cube = Cube(1, attributes=call_args) + yield cube + + +# A Format spec for the fake file format. +FORMAT_X_EXTENSION = "._xtst_" +FORMAT_X = FormatSpecification( + "Testing file handler", + FileExtension(), + FORMAT_X_EXTENSION, + format_x_handler_function, +) + + +class LoadFunctionMixin: + # Common code to test load/load_cube/load_cubes/load_raw. + + # Inheritor must set this to the name of an iris load function. + # Note: storing the function itself causes it to be mis-identified as an + # instance method when called, so doing it by name is clearer. + load_function_name = "xxx" + + @classmethod + def setUpClass(cls): + # Add our dummy format handler to the common Iris io file-picker. + FORMAT_AGENT.add_spec(FORMAT_X) + # Create a temporary working directory. + cls._temp_dir = tempfile.mkdtemp() + # Store the path of a dummy file whose name matches the picker. + filename = "testfile" + FORMAT_X_EXTENSION + cls.test_filepath = os.path.join(cls._temp_dir, filename) + # Create the dummy file. + with open(cls.test_filepath, "w") as f_open: + # Write some data to satisfy the other, signature-based pickers. + # TODO: this really shouldn't be necessary ?? + f_open.write("\x00" * 100) + + @classmethod + def tearDownClass(cls): + # Remove the dummy format handler. + # N.B. no public api, so uses a private property of FormatAgent. + FORMAT_AGENT._format_specs.remove(FORMAT_X) + # Delete the temporary directory. + shutil.rmtree(cls._temp_dir) + + def _load_a_cube(self, *args, **kwargs): + load_function = getattr(iris, self.load_function_name) + result = load_function(*args, loader_kwargs=kwargs) + if load_function is not iris.load_cube: + # Handle 'other' load functions, which return CubeLists ... + self.assertIsInstance(result, CubeList) + # ... however, we intend that all uses will return only 1 cube. + self.assertEqual(len(result), 1) + result = result[0] + self.assertIsInstance(result, Cube) + return result + + def test_extra_args(self): + test_kwargs = {"loader_a": 1, "loader_b": "two"} + result = self._load_a_cube(self.test_filepath, **test_kwargs) + self.assertEqual( + result.attributes, + dict( + fnames=[self.test_filepath], + callback=None, + loader_kwargs=test_kwargs, + ), + ) + + def test_no_extra_args(self): + result = self._load_a_cube(self.test_filepath) + self.assertEqual( + result.attributes, + dict(fnames=[self.test_filepath], callback=None, loader_kwargs={}), + ) + + @tests.skip_data + def test_wrong_loader_noargs_ok(self): + filepath = tests.get_data_path( + ["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"] + ) + result = self._load_a_cube(filepath, "co2") + self.assertIsNot(result, None) + + @tests.skip_data + def test_wrong_loader_withargs__fail(self): + filepath = tests.get_data_path( + ["NetCDF", "global", "xyz_t", "GEMS_CO2_Apr2006.nc"] + ) + test_kwargs = {"junk": "this"} + msg = "load.* got an unexpected keyword argument 'junk'" + with self.assertRaisesRegex(TypeError, msg): + _ = self._load_a_cube(filepath, "co2", **test_kwargs) + + +class TestLoad(LoadFunctionMixin, tests.IrisTest): + load_function_name = "load" + + +class TestLoadCubes(LoadFunctionMixin, tests.IrisTest): + load_function_name = "load_cubes" + + +class TestLoadCube(LoadFunctionMixin, tests.IrisTest): + load_function_name = "load_cube" + + +class TestLoadRaw(LoadFunctionMixin, tests.IrisTest): + load_function_name = "load_raw" + + +if __name__ == "__main__": + tests.main()