diff --git a/iris_grib/tests/__init__.py b/iris_grib/tests/__init__.py index 897feaf5..2d1c1681 100644 --- a/iris_grib/tests/__init__.py +++ b/iris_grib/tests/__init__.py @@ -17,6 +17,8 @@ import os import os.path +import numpy as np + try: from iris.tests import IrisTest_nometa as IrisTest except ImportError: @@ -24,6 +26,8 @@ from iris.tests import main, skip_data, get_data_path +from iris_grib.message import GribMessage + #: Basepath for iris-grib test results. _RESULT_PATH = os.path.join(os.path.dirname(__file__), 'results') @@ -92,3 +96,113 @@ def get_testdata_path(relative_path): if not isinstance(relative_path, str): relative_path = os.path.join(*relative_path) return os.path.abspath(os.path.join(_TESTDATA_PATH, relative_path)) + + +class TestGribMessage(IrisGribTest): + def assertGribMessageContents(self, filename, contents): + """ + Evaluate whether all messages in a GRIB2 file contain the provided + contents. + + * filename (string) + The path on disk of an existing GRIB file + + * contents + An iterable of GRIB message keys and expected values. + + """ + messages = GribMessage.messages_from_filename(filename) + for message in messages: + for element in contents: + section, key, val = element + self.assertEqual(message.sections[section][key], val) + + def assertGribMessageDifference( + self, filename1, filename2, diffs, skip_keys=(), skip_sections=() + ): + """ + Evaluate that the two messages only differ in the ways specified. + + * filename[0|1] (string) + The path on disk of existing GRIB files + + * diffs + An dictionary of GRIB message keys and expected diff values: + {key: (m1val, m2val),...} . + + * skip_keys + An iterable of key names to ignore during comparison. + + * skip_sections + An iterable of section numbers to ignore during comparison. + + """ + messages1 = list(GribMessage.messages_from_filename(filename1)) + messages2 = list(GribMessage.messages_from_filename(filename2)) + self.assertEqual(len(messages1), len(messages2)) + for m1, m2 in zip(messages1, messages2): + m1_sect = set(m1.sections.keys()) + m2_sect = set(m2.sections.keys()) + + for missing_section in m1_sect ^ m2_sect: + what = ( + "introduced" if missing_section in m1_sect else "removed" + ) + # Assert that an introduced section is in the diffs. + self.assertIn( + missing_section, + skip_sections, + msg="Section {} {}".format(missing_section, what), + ) + + for section in m1_sect & m2_sect: + # For each section, check that the differences are + # known diffs. + m1_keys = set(m1.sections[section]._keys) + m2_keys = set(m2.sections[section]._keys) + + difference = m1_keys ^ m2_keys + unexpected_differences = difference - set(skip_keys) + if unexpected_differences: + self.fail( + "There were keys in section {} which \n" + "weren't in both messages and which weren't " + "skipped.\n{}" + "".format(section, ", ".join(unexpected_differences)) + ) + + keys_to_compare = m1_keys & m2_keys - set(skip_keys) + + for key in keys_to_compare: + m1_value = m1.sections[section][key] + m2_value = m2.sections[section][key] + msg = "{} {} != {}" + if key not in diffs: + # We have a key which we expect to be the same for + # both messages. + if isinstance(m1_value, np.ndarray): + # A large tolerance appears to be required for + # gribapi 1.12, but not for 1.14. + self.assertArrayAlmostEqual( + m1_value, m2_value, decimal=2 + ) + else: + self.assertEqual( + m1_value, + m2_value, + msg=msg.format(key, m1_value, m2_value), + ) + else: + # We have a key which we expect to be different + # for each message. + self.assertEqual( + m1_value, + diffs[key][0], + msg=msg.format(key, m1_value, diffs[key][0]), + ) + + self.assertEqual( + m2_value, + diffs[key][1], + msg=msg.format(key, m2_value, diffs[key][1]), + ) diff --git a/iris_grib/tests/integration/round_trip/test_grid_definition_section.py b/iris_grib/tests/integration/round_trip/test_grid_definition_section.py index a7bcc122..7a3c508c 100644 --- a/iris_grib/tests/integration/round_trip/test_grid_definition_section.py +++ b/iris_grib/tests/integration/round_trip/test_grid_definition_section.py @@ -17,12 +17,10 @@ from iris.fileformats.pp import EARTH_RADIUS as UM_DEFAULT_EARTH_RADIUS from iris.util import is_regular -from iris.tests import TestGribMessage - from iris_grib.grib_phenom_translation import GRIBCode -class TestGDT5(TestGribMessage): +class TestGDT5(tests.TestGribMessage): def test_save_load(self): # Load sample UKV data (variable-resolution rotated grid). path = tests.get_data_path(("PP", "ukV1", "ukVpmslont.pp")) diff --git a/iris_grib/tests/integration/round_trip/test_product_definition_section.py b/iris_grib/tests/integration/round_trip/test_product_definition_section.py index b0fed1cc..72721380 100644 --- a/iris_grib/tests/integration/round_trip/test_product_definition_section.py +++ b/iris_grib/tests/integration/round_trip/test_product_definition_section.py @@ -18,13 +18,12 @@ import iris.coords import iris.coord_systems -from iris.tests import TestGribMessage import iris.tests.stock as stock from iris_grib.grib_phenom_translation import GRIBCode -class TestPDT11(TestGribMessage): +class TestPDT11(tests.TestGribMessage): def test_perturbation(self): path = tests.get_data_path( ("NetCDF", "global", "xyt", "SMALL_hires_wind_u_for_ipcc4.nc") diff --git a/iris_grib/tests/integration/save_rules/test_grib_save.py b/iris_grib/tests/integration/save_rules/test_grib_save.py index 9cbb7177..4cf3665b 100644 --- a/iris_grib/tests/integration/save_rules/test_grib_save.py +++ b/iris_grib/tests/integration/save_rules/test_grib_save.py @@ -26,13 +26,11 @@ import iris.exceptions import iris.util -from iris.tests import TestGribMessage - import gribapi from iris_grib._load_convert import _MDI as MDI -class TestLoadSave(TestGribMessage): +class TestLoadSave(tests.TestGribMessage): def setUp(self): self.skip_keys = []