diff --git a/yt/data_objects/static_output.py b/yt/data_objects/static_output.py index f211606997d..78cf39ff25f 100644 --- a/yt/data_objects/static_output.py +++ b/yt/data_objects/static_output.py @@ -9,6 +9,7 @@ from stat import ST_CTIME import numpy as np +from unyt.exceptions import UnitConversionError, UnitParseError from yt.config import ytcfg from yt.data_objects.particle_filters import filter_registry @@ -215,9 +216,7 @@ def __init__( self.known_filters = self.known_filters or {} self.particle_unions = self.particle_unions or {} self.field_units = self.field_units or {} - if units_override is None: - units_override = {} - self.units_override = units_override + self.units_override = self.__class__._sanitize_units_override(units_override) # path stuff self.parameter_filename = str(filename) @@ -1116,6 +1115,8 @@ def _assign_unit_system(self, unit_system): self.unit_registry.unit_system = self.unit_system def _create_unit_registry(self, unit_system): + from yt.units import dimensions + # yt assumes a CGS unit system by default (for back compat reasons). # Since unyt is MKS by default we specify the MKS values of the base # units in the CGS system. So, for length, 1 cm = .01 m. And so on. @@ -1243,29 +1244,113 @@ def set_code_units(self): "unitary", float(DW.max() * DW.units.base_value), DW.units.dimensions ) + @classmethod + def _validate_units_override_keys(cls, units_override): + valid_keys = set(cls.default_units.keys()) + invalid_keys_found = set(units_override.keys()) - valid_keys + if invalid_keys_found: + raise ValueError( + f"units_override contains invalid keys: {invalid_keys_found}" + ) + + default_units = { + "length_unit": "cm", + "time_unit": "s", + "mass_unit": "g", + "velocity_unit": "cm/s", + "magnetic_unit": "gauss", + "temperature_unit": "K", + } + + @classmethod + def _sanitize_units_override(cls, units_override): + """ + Convert units_override values to valid input types for unyt. + Throw meaningful errors early if units_override is ill-formed. + + Parameters + ---------- + units_override : dict + + keys should be strings with format "_unit" (e.g. "mass_unit"), and + need to match a key in cls.default_units + + values should be mappable to unyt.unyt_quantity objects, and can be any + combinations of: + - unyt.unyt_quantity + - 2-long sequence (tuples, list, ...) with types (number, str) + e.g. (10, "km"), (0.1, "s") + - number (in which case the associated is taken from cls.default_unit) + + + Raises + ------ + TypeError + If unit_override has invalid types + + ValueError + If provided units do not match the intended dimensionality, + or in case of a zero scaling factor. + + """ + uo = {} + if units_override is None: + return uo + + cls._validate_units_override_keys(units_override) + + for key in cls.default_units: + try: + val = units_override[key] + except KeyError: + continue + + # Now attempt to instanciate a unyt.unyt_quantity from val ... + try: + # ... directly (valid if val is a number, or a unyt_quantity) + uo[key] = YTQuantity(val) + continue + except RuntimeError: + # note that unyt.unyt_quantity throws RuntimeError in lieu of TypeError + pass + try: + # ... with tuple unpacking (valid if val is a sequence) + uo[key] = YTQuantity(*val) + continue + except (RuntimeError, TypeError, UnitParseError): + pass + raise TypeError( + "units_override values should be 2-sequence (float, str), " + "YTQuantity objects or real numbers; " + f"received {val} with type {type(val)}." + ) + for key, q in uo.items(): + if q.units.is_dimensionless: + uo[key] = YTQuantity(q, cls.default_units[key]) + try: + uo[key].to(cls.default_units[key]) + except UnitConversionError as err: + raise ValueError( + "Inconsistent dimensionality in units_override. " + f"Received {key} = {uo[key]}" + ) from err + if 1 / uo[key].value == np.inf: + raise ValueError( + f"Invalid 0 normalisation factor in units_override for {key}." + ) + return uo + def _override_code_units(self): - if len(self.units_override) == 0: + if not self.units_override: return + mylog.warning( "Overriding code units: Use this option only if you know that the " "dataset doesn't define the units correctly or at all." ) - for unit, cgs in [ - ("length", "cm"), - ("time", "s"), - ("mass", "g"), - ("velocity", "cm/s"), - ("magnetic", "gauss"), - ("temperature", "K"), - ]: - val = self.units_override.get(f"{unit}_unit", None) - if val is not None: - if isinstance(val, YTQuantity): - val = (val.v, str(val.units)) - elif not isinstance(val, tuple): - val = (val, cgs) - mylog.info("Overriding %s_unit: %g %s.", unit, val[0], val[1]) - setattr(self, f"{unit}_unit", self.quan(val[0], val[1])) + for ukey, val in self.units_override.items(): + mylog.info("Overriding %s: %s.", ukey, val) + setattr(self, ukey, self.quan(val)) _units = None _unit_system_id = None diff --git a/yt/data_objects/tests/test_units_override.py b/yt/data_objects/tests/test_units_override.py new file mode 100644 index 00000000000..c874a4ba433 --- /dev/null +++ b/yt/data_objects/tests/test_units_override.py @@ -0,0 +1,70 @@ +from functools import partial + +from yt.data_objects.static_output import Dataset +from yt.testing import assert_raises +from yt.units import YTQuantity +from yt.units.unit_registry import UnitRegistry + +mock_quan = partial(YTQuantity, registry=UnitRegistry()) + + +def test_schema_validation(): + + valid_schemas = [ + {"length_unit": 1.0}, + {"length_unit": [1.0]}, + {"length_unit": (1.0,)}, + {"length_unit": int(1.0)}, + {"length_unit": (1.0, "m")}, + {"length_unit": [1.0, "m"]}, + {"length_unit": YTQuantity(1.0, "m")}, + ] + + for schema in valid_schemas: + uo = Dataset._sanitize_units_override(schema) + for v in uo.values(): + q = mock_quan(v) # check that no error (TypeError) is raised + q.to("pc") # check that q is a length + + +def test_invalid_schema_detection(): + invalid_key_schemas = [ + {"len_unit": 1.0}, # plain invalid key + {"lenght_unit": 1.0}, # typo + ] + for invalid_schema in invalid_key_schemas: + assert_raises(ValueError, Dataset._sanitize_units_override, invalid_schema) + + invalid_val_schemas = [ + {"length_unit": [1, 1, 1]}, # len(val) > 2 + {"length_unit": [1, 1, 1, 1, 1]}, # "data type not understood" in unyt + ] + + for invalid_schema in invalid_val_schemas: + assert_raises(TypeError, Dataset._sanitize_units_override, invalid_schema) + + # 0 shouldn't make sense + invalid_number_schemas = [ + {"length_unit": 0}, + {"length_unit": [0]}, + {"length_unit": (0,)}, + {"length_unit": (0, "cm")}, + ] + for invalid_schema in invalid_number_schemas: + assert_raises(ValueError, Dataset._sanitize_units_override, invalid_schema) + + +def test_typing_error_detection(): + invalid_schema = {"length_unit": "1m"} + + # this is the error that is raised by unyt on bad input + assert_raises(RuntimeError, mock_quan, invalid_schema["length_unit"]) + + # check that the sanitizer function is able to catch the + # type issue before passing down to unyt + assert_raises(TypeError, Dataset._sanitize_units_override, invalid_schema) + + +def test_dimensionality_error_detection(): + invalid_schema = {"length_unit": YTQuantity(1.0, "s")} + assert_raises(ValueError, Dataset._sanitize_units_override, invalid_schema) diff --git a/yt/frontends/amrvac/data_structures.py b/yt/frontends/amrvac/data_structures.py index 23650b45d68..3934d937174 100644 --- a/yt/frontends/amrvac/data_structures.py +++ b/yt/frontends/amrvac/data_structures.py @@ -4,8 +4,6 @@ """ - - import os import stat import struct @@ -24,14 +22,6 @@ from .datfile_utils import get_header, get_tree_info from .fields import AMRVACFieldInfo -ALLOWED_UNIT_COMBINATIONS = [ - {"numberdensity_unit", "temperature_unit", "length_unit"}, - {"mass_unit", "temperature_unit", "length_unit"}, - {"mass_unit", "time_unit", "length_unit"}, - {"numberdensity_unit", "velocity_unit", "length_unit"}, - {"mass_unit", "velocity_unit", "length_unit"}, -] - class AMRVACGrid(AMRGridPatch): """A class to populate AMRVACHierarchy.grids, setting parent/children relations.""" @@ -86,23 +76,17 @@ def __init__(self, ds, dataset_type="amrvac"): super(AMRVACHierarchy, self).__init__(ds, dataset_type) def _detect_output_fields(self): - """ - Parse field names from datfile header, as stored in self.dataset.parameters - - """ - # required method + """Parse field names from the header, as stored in self.dataset.parameters""" self.field_list = [ (self.dataset_type, f) for f in self.dataset.parameters["w_names"] ] def _count_grids(self): """Set self.num_grids from datfile header.""" - # required method self.num_grids = self.dataset.parameters["nleafs"] def _parse_index(self): """Populate self.grid_* attributes from tree info from datfile header.""" - # required method with open(self.index_filename, "rb") as istream: vaclevels, morton_indices, block_offsets = get_tree_info(istream) assert ( @@ -243,7 +227,6 @@ def __init__( @classmethod def _is_valid(self, *args, **kwargs): """At load time, check whether data is recognized as AMRVAC formatted.""" - # required class method validation = False if args[0].endswith(".dat"): try: @@ -278,7 +261,7 @@ def _parse_geometry(self, geometry_tag): Returns ------- geometry_yt : str - Lower case geometry tag "cartesian", "polar", "cylindrical" or "spherical" + Lower case geometry tag ("cartesian", "polar", "cylindrical" or "spherical") Examples -------- @@ -384,17 +367,9 @@ def _parse_parameter_file(self): # units stuff ====================================================================== def _set_code_unit_attributes(self): """Reproduce how AMRVAC internally set up physical normalisation factors.""" - # required method - # devnote: this method is never defined in the parent abstract class Dataset - # but it is called in Dataset.set_code_units(), which is part of - # Dataset.__init__() so it must be defined here. - - # devnote: this gets called later than Dataset._override_code_units() + # This gets called later than Dataset._override_code_units() # This is the reason why it uses setdefaultattr: it will only fill in the gaps # left by the "override", instead of overriding them again. - # For the same reason, self.units_override is set, as well as corresponding - # *_unit instance attributes which may include up to 3 of the following items: - # length, time, mass, velocity, number_density, temperature # note: yt sets hydrogen mass equal to proton mass, amrvac doesn't. mp_cgs = self.quan(1.672621898e-24, "g") # This value is taken from AstroPy @@ -408,23 +383,17 @@ def _set_code_unit_attributes(self): # in this case unit_mass is supplied (and has been set as attribute) mass_unit = self.mass_unit density_unit = mass_unit / length_unit ** 3 - numberdensity_unit = density_unit / ((1.0 + 4.0 * He_abundance) * mp_cgs) + nd_unit = density_unit / ((1.0 + 4.0 * He_abundance) * mp_cgs) else: # other case: numberdensity is supplied. # Fall back to one (default) if no overrides supplied - numberdensity_override = self.units_override.get( - "numberdensity_unit", (1, "cm**-3") - ) - if ( - "numberdensity_unit" in self.units_override - ): # print similar warning as yt when overriding numberdensity - mylog.info( - "Overriding numberdensity_unit: %g %s.", *numberdensity_override + try: + nd_unit = self.quan(self.units_override["numberdensity_unit"]) + except KeyError: + nd_unit = self.quan( + 1.0, self.__class__.default_units["numberdensity_unit"] ) - numberdensity_unit = self.quan( - *numberdensity_override - ) # numberdensity is never set as attribute - density_unit = (1.0 + 4.0 * He_abundance) * mp_cgs * numberdensity_unit + density_unit = (1.0 + 4.0 * He_abundance) * mp_cgs * nd_unit mass_unit = density_unit * length_unit ** 3 # 2. calculations for velocity @@ -442,18 +411,14 @@ def _set_code_unit_attributes(self): # Fall back to one (default) if not temperature_unit = getattr(self, "temperature_unit", self.quan(1, "K")) pressure_unit = ( - (2.0 + 3.0 * He_abundance) - * numberdensity_unit - * kb_cgs - * temperature_unit + (2.0 + 3.0 * He_abundance) * nd_unit * kb_cgs * temperature_unit ).in_cgs() velocity_unit = (np.sqrt(pressure_unit / density_unit)).in_cgs() else: # velocity is not zero if either time was given OR velocity was given pressure_unit = (density_unit * velocity_unit ** 2).in_cgs() temperature_unit = ( - pressure_unit - / ((2.0 + 3.0 * He_abundance) * numberdensity_unit * kb_cgs) + pressure_unit / ((2.0 + 3.0 * He_abundance) * nd_unit * kb_cgs) ).in_cgs() # 4. calculations for magnetic unit and time @@ -464,7 +429,6 @@ def _set_code_unit_attributes(self): setdefaultattr(self, "mass_unit", mass_unit) setdefaultattr(self, "density_unit", density_unit) - setdefaultattr(self, "numberdensity_unit", numberdensity_unit) setdefaultattr(self, "length_unit", length_unit) setdefaultattr(self, "velocity_unit", velocity_unit) @@ -474,48 +438,61 @@ def _set_code_unit_attributes(self): setdefaultattr(self, "pressure_unit", pressure_unit) setdefaultattr(self, "magnetic_unit", magnetic_unit) - def _override_code_units(self): - """Add a check step to the base class' method (Dataset).""" - self._check_override_consistency() - super(AMRVACDataset, self)._override_code_units() - - def _check_override_consistency(self): - """Check that keys in units_override are consistent with respect to AMRVAC's - internal way to set up normalisations factors. + allowed_unit_combinations = [ + {"numberdensity_unit", "temperature_unit", "length_unit"}, + {"mass_unit", "temperature_unit", "length_unit"}, + {"mass_unit", "time_unit", "length_unit"}, + {"numberdensity_unit", "velocity_unit", "length_unit"}, + {"mass_unit", "velocity_unit", "length_unit"}, + ] + + default_units = { + "length_unit": "cm", + "time_unit": "s", + "mass_unit": "g", + "velocity_unit": "cm/s", + "magnetic_unit": "gauss", + "temperature_unit": "K", + # this is the one difference with Dataset.default_units: + # we accept numberdensity_unit as a valid override + "numberdensity_unit": "cm**-3", + } + @classmethod + def _validate_units_override_keys(cls, units_override): + """Check that keys in units_override are consistent with AMRVAC's internal + normalisations factors. """ - # frontend specific method # YT supports overriding other normalisations, this method ensures consistency # between supplied 'units_override' items and those used by AMRVAC. # AMRVAC's normalisations/units have 3 degrees of freedom. # Moreover, if temperature unit is specified then velocity unit will be # calculated accordingly, and vice-versa. - # We replicate this by allowing a finite set of combinations in units_override - if not self.units_override: - return - overrides = set(self.units_override) + # We replicate this by allowing a finite set of combinations. # there are only three degrees of freedom, so explicitly check for this - if len(overrides) > 3: + if len(units_override) > 3: raise ValueError( "More than 3 degrees of freedom were specified " - "in units_override ({} given)".format(len(overrides)) + f"in units_override ({len(units_override)} given)" ) # temperature and velocity cannot both be specified - if "temperature_unit" in overrides and "velocity_unit" in overrides: + if "temperature_unit" in units_override and "velocity_unit" in units_override: raise ValueError( "Either temperature or velocity is allowed in units_override, not both." ) # check if provided overrides are allowed - for allowed_combo in ALLOWED_UNIT_COMBINATIONS: - if overrides.issubset(allowed_combo): + suo = set(units_override) + for allowed_combo in cls.allowed_unit_combinations: + if suo.issubset(allowed_combo): break else: raise ValueError( - "Combination {} passed to units_override " - "is not consistent with AMRVAC. \n" - "Allowed combinations are {}".format( - overrides, ALLOWED_UNIT_COMBINATIONS - ) + f"Combination {suo} passed to units_override " + "is not consistent with AMRVAC.\n" + f"Allowed combinations are {cls.allowed_unit_combinations}" ) + + # syntax for mixing super with classmethod is weird... + super(cls, cls)._validate_units_override_keys(units_override) diff --git a/yt/frontends/amrvac/tests/test_outputs.py b/yt/frontends/amrvac/tests/test_outputs.py index 0aae58e6bb9..0e7f89747d8 100644 --- a/yt/frontends/amrvac/tests/test_outputs.py +++ b/yt/frontends/amrvac/tests/test_outputs.py @@ -2,8 +2,8 @@ import yt # NOQA from yt.frontends.amrvac.api import AMRVACDataset, AMRVACGrid -from yt.testing import assert_allclose_units, assert_raises, requires_file -from yt.units import YTQuantity +from yt.testing import requires_file +from yt.units import YTArray from yt.utilities.answer_testing.framework import ( data_dir_load, requires_ds, @@ -59,8 +59,8 @@ def test_grid_attributes(): assert ds.index.max_level == 2 for g in grids: assert isinstance(g, AMRVACGrid) - assert isinstance(g.LeftEdge, yt.units.yt_array.YTArray) - assert isinstance(g.RightEdge, yt.units.yt_array.YTArray) + assert isinstance(g.LeftEdge, YTArray) + assert isinstance(g.RightEdge, YTArray) assert isinstance(g.ActiveDimensions, np.ndarray) assert isinstance(g.Level, (np.int32, np.int64, int)) @@ -136,126 +136,3 @@ def test_rmi_cartesian_dust_2D(): for test in small_patch_amr(ds, _get_fields_to_check(ds)): test_rmi_cartesian_dust_2D.__name__ = test.description yield test - - -# Tests for units: verify that overriding certain units yields the correct derived units -# The following are correct normalisations based on length, numberdensity and temp -length_unit = (1e9, "cm") -numberdensity_unit = (1e9, "cm**-3") -temperature_unit = (1e6, "K") -density_unit = (2.341670657200000e-15, "g*cm**-3") -mass_unit = (2.341670657200000e12, "g") -velocity_unit = (1.164508387441102e07, "cm*s**-1") -pressure_unit = (3.175492240000000e-01, "dyn*cm**-2") -time_unit = (8.587314705370271e01, "s") -magnetic_unit = (1.997608879907716, "gauss") - - -def _assert_normalisations_equal(ds): - assert_allclose_units(ds.length_unit, YTQuantity(*length_unit)) - assert_allclose_units(ds.numberdensity_unit, YTQuantity(*numberdensity_unit)) - assert_allclose_units(ds.temperature_unit, YTQuantity(*temperature_unit)) - assert_allclose_units(ds.density_unit, YTQuantity(*density_unit)) - assert_allclose_units(ds.mass_unit, YTQuantity(*mass_unit)) - assert_allclose_units(ds.velocity_unit, YTQuantity(*velocity_unit)) - assert_allclose_units(ds.pressure_unit, YTQuantity(*pressure_unit)) - assert_allclose_units(ds.time_unit, YTQuantity(*time_unit)) - assert_allclose_units(ds.magnetic_unit, YTQuantity(*magnetic_unit)) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_length_temp_nb(): - # overriding length, temperature, numberdensity - overrides = dict( - length_unit=length_unit, - temperature_unit=temperature_unit, - numberdensity_unit=numberdensity_unit, - ) - ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - _assert_normalisations_equal(ds) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_length_temp_mass(): - # overriding length, temperature, mass - overrides = dict( - length_unit=length_unit, temperature_unit=temperature_unit, mass_unit=mass_unit - ) - ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - _assert_normalisations_equal(ds) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_length_time_mass(): - # overriding length, time, mass - overrides = dict(length_unit=length_unit, time_unit=time_unit, mass_unit=mass_unit) - ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - _assert_normalisations_equal(ds) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_length_vel_nb(): - # overriding length, velocity, numberdensity - overrides = dict( - length_unit=length_unit, - velocity_unit=velocity_unit, - numberdensity_unit=numberdensity_unit, - ) - ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - _assert_normalisations_equal(ds) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_length_vel_mass(): - # overriding length, velocity, mass - overrides = dict( - length_unit=length_unit, velocity_unit=velocity_unit, mass_unit=mass_unit - ) - ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - _assert_normalisations_equal(ds) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_default(): - # test default normalisations, without overrides - ds = data_dir_load(khi_cartesian_2D) - assert_allclose_units(ds.length_unit, YTQuantity(1, "cm")) - assert_allclose_units(ds.numberdensity_unit, YTQuantity(1, "cm**-3")) - assert_allclose_units(ds.temperature_unit, YTQuantity(1, "K")) - assert_allclose_units( - ds.density_unit, YTQuantity(2.341670657200000e-24, "g*cm**-3") - ) - assert_allclose_units(ds.mass_unit, YTQuantity(2.341670657200000e-24, "g")) - assert_allclose_units( - ds.velocity_unit, YTQuantity(1.164508387441102e04, "cm*s**-1") - ) - assert_allclose_units( - ds.pressure_unit, YTQuantity(3.175492240000000e-16, "dyn*cm**-2") - ) - assert_allclose_units(ds.time_unit, YTQuantity(8.587314705370271e-05, "s")) - assert_allclose_units(ds.magnetic_unit, YTQuantity(6.316993934686148e-08, "gauss")) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_too_many_args(): - # test forbidden case: too many arguments (max 3 are allowed) - overrides = dict( - length_unit=length_unit, - numberdensity_unit=numberdensity_unit, - temperature_unit=temperature_unit, - time_unit=time_unit, - ) - with assert_raises(ValueError): - data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) - - -@requires_file(khi_cartesian_2D) -def test_normalisations_vel_and_length(): - # test forbidden case: both velocity and temperature are specified as overrides - overrides = dict( - length_unit=length_unit, - velocity_unit=velocity_unit, - temperature_unit=temperature_unit, - ) - with assert_raises(ValueError): - data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) diff --git a/yt/frontends/amrvac/tests/test_units_override.py b/yt/frontends/amrvac/tests/test_units_override.py new file mode 100644 index 00000000000..3d3074cdadb --- /dev/null +++ b/yt/frontends/amrvac/tests/test_units_override.py @@ -0,0 +1,126 @@ +from yt.testing import assert_allclose_units, assert_raises, requires_file +from yt.units import YTQuantity +from yt.utilities.answer_testing.framework import data_dir_load + +khi_cartesian_2D = "amrvac/kh_2d0000.dat" + +# Tests for units: check that overriding certain units yields the correct derived units. +# The following are the correct normalisations +# based on length, numberdensity and temperature +length_unit = (1e9, "cm") +numberdensity_unit = (1e9, "cm**-3") +temperature_unit = (1e6, "K") +density_unit = (2.341670657200000e-15, "g*cm**-3") +mass_unit = (2.341670657200000e12, "g") +velocity_unit = (1.164508387441102e07, "cm*s**-1") +pressure_unit = (3.175492240000000e-01, "dyn*cm**-2") +time_unit = (8.587314705370271e01, "s") +magnetic_unit = (1.997608879907716, "gauss") + + +def _assert_normalisations_equal(ds): + assert_allclose_units(ds.length_unit, YTQuantity(*length_unit)) + assert_allclose_units(ds.temperature_unit, YTQuantity(*temperature_unit)) + assert_allclose_units(ds.density_unit, YTQuantity(*density_unit)) + assert_allclose_units(ds.mass_unit, YTQuantity(*mass_unit)) + assert_allclose_units(ds.velocity_unit, YTQuantity(*velocity_unit)) + assert_allclose_units(ds.pressure_unit, YTQuantity(*pressure_unit)) + assert_allclose_units(ds.time_unit, YTQuantity(*time_unit)) + assert_allclose_units(ds.magnetic_unit, YTQuantity(*magnetic_unit)) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_length_temp_nb(): + # overriding length, temperature, numberdensity + overrides = dict( + length_unit=length_unit, + temperature_unit=temperature_unit, + numberdensity_unit=numberdensity_unit, + ) + ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + _assert_normalisations_equal(ds) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_length_temp_mass(): + # overriding length, temperature, mass + overrides = dict( + length_unit=length_unit, temperature_unit=temperature_unit, mass_unit=mass_unit + ) + ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + _assert_normalisations_equal(ds) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_length_time_mass(): + # overriding length, time, mass + overrides = dict(length_unit=length_unit, time_unit=time_unit, mass_unit=mass_unit) + ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + _assert_normalisations_equal(ds) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_length_vel_nb(): + # overriding length, velocity, numberdensity + overrides = dict( + length_unit=length_unit, + velocity_unit=velocity_unit, + numberdensity_unit=numberdensity_unit, + ) + ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + _assert_normalisations_equal(ds) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_length_vel_mass(): + # overriding length, velocity, mass + overrides = dict( + length_unit=length_unit, velocity_unit=velocity_unit, mass_unit=mass_unit + ) + ds = data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + _assert_normalisations_equal(ds) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_default(): + # test default normalisations, without overrides + ds = data_dir_load(khi_cartesian_2D) + assert_allclose_units(ds.length_unit, YTQuantity(1, "cm")) + assert_allclose_units(ds.temperature_unit, YTQuantity(1, "K")) + assert_allclose_units( + ds.density_unit, YTQuantity(2.341670657200000e-24, "g*cm**-3") + ) + assert_allclose_units(ds.mass_unit, YTQuantity(2.341670657200000e-24, "g")) + assert_allclose_units( + ds.velocity_unit, YTQuantity(1.164508387441102e04, "cm*s**-1") + ) + assert_allclose_units( + ds.pressure_unit, YTQuantity(3.175492240000000e-16, "dyn*cm**-2") + ) + assert_allclose_units(ds.time_unit, YTQuantity(8.587314705370271e-05, "s")) + assert_allclose_units(ds.magnetic_unit, YTQuantity(6.316993934686148e-08, "gauss")) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_too_many_args(): + # test forbidden case: too many arguments (max 3 are allowed) + overrides = dict( + length_unit=length_unit, + numberdensity_unit=numberdensity_unit, + temperature_unit=temperature_unit, + time_unit=time_unit, + ) + with assert_raises(ValueError): + data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) + + +@requires_file(khi_cartesian_2D) +def test_normalisations_vel_and_length(): + # test forbidden case: both velocity and temperature are specified as overrides + overrides = dict( + length_unit=length_unit, + velocity_unit=velocity_unit, + temperature_unit=temperature_unit, + ) + with assert_raises(ValueError): + data_dir_load(khi_cartesian_2D, kwargs={"units_override": overrides}) diff --git a/yt/testing.py b/yt/testing.py index caced3e68e2..0ad6081c840 100644 --- a/yt/testing.py +++ b/yt/testing.py @@ -867,7 +867,7 @@ def units_override_check(fn): unit_attr = getattr(ds1, f"{u}_unit", None) if unit_attr is not None: attrs1.append(unit_attr) - units_override[f"{u}_unit"] = (unit_attr.v, str(unit_attr.units)) + units_override[f"{u}_unit"] = (unit_attr.v, unit_attr.units) del ds1 ds2 = load(fn, units_override=units_override) assert len(ds2.units_override) > 0