Skip to content

Commit

Permalink
Merge pull request #2785 from neutrinoceros/units_override_sanitizing
Browse files Browse the repository at this point in the history
Units override sanitizing
  • Loading branch information
munkm authored Sep 18, 2020
2 parents 0bcbd98 + a341f82 commit ff56ee1
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 219 deletions.
125 changes: 105 additions & 20 deletions yt/data_objects/static_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from stat import ST_CTIME

import numpy as np
from unyt.exceptions import UnitConversionError, UnitParseError

from yt.config import ytcfg
from yt.data_objects.particle_filters import filter_registry
Expand Down Expand Up @@ -215,9 +216,7 @@ def __init__(
self.known_filters = self.known_filters or {}
self.particle_unions = self.particle_unions or {}
self.field_units = self.field_units or {}
if units_override is None:
units_override = {}
self.units_override = units_override
self.units_override = self.__class__._sanitize_units_override(units_override)

# path stuff
self.parameter_filename = str(filename)
Expand Down Expand Up @@ -1116,6 +1115,8 @@ def _assign_unit_system(self, unit_system):
self.unit_registry.unit_system = self.unit_system

def _create_unit_registry(self, unit_system):
from yt.units import dimensions

# yt assumes a CGS unit system by default (for back compat reasons).
# Since unyt is MKS by default we specify the MKS values of the base
# units in the CGS system. So, for length, 1 cm = .01 m. And so on.
Expand Down Expand Up @@ -1243,29 +1244,113 @@ def set_code_units(self):
"unitary", float(DW.max() * DW.units.base_value), DW.units.dimensions
)

@classmethod
def _validate_units_override_keys(cls, units_override):
valid_keys = set(cls.default_units.keys())
invalid_keys_found = set(units_override.keys()) - valid_keys
if invalid_keys_found:
raise ValueError(
f"units_override contains invalid keys: {invalid_keys_found}"
)

# Accepted units_override keys, each mapped to the unit string that
# _sanitize_units_override attaches to a bare (dimensionless) number and
# uses as the conversion target when checking dimensionality.
default_units = {
"length_unit": "cm",
"time_unit": "s",
"mass_unit": "g",
"velocity_unit": "cm/s",
"magnetic_unit": "gauss",
"temperature_unit": "K",
}

@classmethod
def _sanitize_units_override(cls, units_override):
    """
    Convert units_override values to valid input types for unyt.

    Throw meaningful errors early if units_override is ill-formed.

    Parameters
    ----------
    units_override : dict or None
        keys should be strings with format "<dim>_unit" (e.g. "mass_unit"), and
        need to match a key in cls.default_units
        values should be mappable to unyt.unyt_quantity objects, and can be any
        combinations of:
            - unyt.unyt_quantity
            - 2-long sequence (tuples, list, ...) with types (number, str)
              e.g. (10, "km"), (0.1, "s")
            - number (in which case the associated unit is taken from
              cls.default_units)
        None means "no override".

    Returns
    -------
    dict
        Maps each provided key to a YTQuantity whose units are consistent
        with cls.default_units[key]. Empty if units_override is None.

    Raises
    ------
    TypeError
        If a value of units_override has an invalid type
    ValueError
        If units_override contains keys missing from cls.default_units,
        if provided units do not match the intended dimensionality,
        or in case of a zero scaling factor.
    """
    uo = {}
    if units_override is None:
        return uo

    cls._validate_units_override_keys(units_override)

    # First pass: coerce every provided value to a YTQuantity, so that type
    # errors are reported before any dimensionality error.
    for key in cls.default_units:
        try:
            val = units_override[key]
        except KeyError:
            continue

        # Now attempt to instantiate a unyt.unyt_quantity from val ...
        try:
            # ... directly (valid if val is a number, or a unyt_quantity)
            uo[key] = YTQuantity(val)
            continue
        except RuntimeError:
            # note that unyt.unyt_quantity throws RuntimeError in lieu of TypeError
            pass
        try:
            # ... with tuple unpacking (valid if val is a sequence)
            uo[key] = YTQuantity(*val)
            continue
        except (RuntimeError, TypeError, UnitParseError):
            pass
        raise TypeError(
            "units_override values should be 2-sequence (float, str), "
            "YTQuantity objects or real numbers; "
            f"received {val} with type {type(val)}."
        )

    # Second pass: attach default units and validate. Iterate over a snapshot
    # of the keys because entries of uo may be reassigned in the loop.
    for key in list(uo):
        default_unit = cls.default_units[key]
        if uo[key].units.is_dimensionless:
            uo[key] = YTQuantity(uo[key], default_unit)
        try:
            # the converted value is discarded: only the dimensionality
            # check performed by the conversion matters here
            uo[key].to(default_unit)
        except UnitConversionError as err:
            raise ValueError(
                "Inconsistent dimensionality in units_override. "
                f"Received {key} = {uo[key]}"
            ) from err
        # A factor whose reciprocal is infinite cannot be used as a scaling.
        # np.isinf also catches negative zero (reciprocal is -inf), and the
        # errstate context avoids a spurious divide-by-zero RuntimeWarning.
        with np.errstate(divide="ignore", over="ignore"):
            inverse = 1 / uo[key].value
        if np.isinf(inverse):
            raise ValueError(
                f"Invalid 0 normalisation factor in units_override for {key}."
            )
    return uo

def _override_code_units(self):
    """
    Apply the entries of self.units_override as attributes of the dataset.

    Does nothing when self.units_override is empty. Otherwise a warning is
    logged once, then for each (key, value) pair the override is logged and
    the attribute named by the key (e.g. "length_unit") is set to
    self.quan(value).

    The per-unit loop that converted raw numbers, tuples and YTQuantity
    values by hand is not needed: the values stored in self.units_override
    come out of _sanitize_units_override, and every key is assigned by the
    single loop below.
    """
    if not self.units_override:
        return

    mylog.warning(
        "Overriding code units: Use this option only if you know that the "
        "dataset doesn't define the units correctly or at all."
    )
    for ukey, val in self.units_override.items():
        mylog.info("Overriding %s: %s.", ukey, val)
        setattr(self, ukey, self.quan(val))

_units = None
_unit_system_id = None
Expand Down
70 changes: 70 additions & 0 deletions yt/data_objects/tests/test_units_override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from functools import partial

from yt.data_objects.static_output import Dataset
from yt.testing import assert_raises
from yt.units import YTQuantity
from yt.units.unit_registry import UnitRegistry

# Quantity factory: YTQuantity pre-bound to a freshly constructed UnitRegistry,
# used by the tests below to build quantities from sanitized values.
mock_quan = partial(YTQuantity, registry=UnitRegistry())


def test_schema_validation():
    """Every supported form of a length override sanitizes to a usable length."""
    accepted_values = (
        1.0,
        [1.0],
        (1.0,),
        int(1.0),
        (1.0, "m"),
        [1.0, "m"],
        YTQuantity(1.0, "m"),
    )

    for value in accepted_values:
        sanitized = Dataset._sanitize_units_override({"length_unit": value})
        for item in sanitized.values():
            # building a quantity from the sanitized value must not raise
            quantity = mock_quan(item)
            # converting to parsecs only succeeds for a length
            quantity.to("pc")


def test_invalid_schema_detection():
    """Ill-formed overrides are rejected with the expected exception type."""
    sanitize = Dataset._sanitize_units_override

    expectations = [
        # unknown keys
        (
            ValueError,
            [
                {"len_unit": 1.0},  # plain invalid key
                {"lenght_unit": 1.0},  # typo
            ],
        ),
        # values that cannot be turned into a quantity
        (
            TypeError,
            [
                {"length_unit": [1, 1, 1]},  # len(val) > 2
                {"length_unit": [1, 1, 1, 1, 1]},  # "data type not understood" in unyt
            ],
        ),
        # a zero scaling factor is meaningless
        (
            ValueError,
            [
                {"length_unit": 0},
                {"length_unit": [0]},
                {"length_unit": (0,)},
                {"length_unit": (0, "cm")},
            ],
        ),
    ]

    for expected_error, schemas in expectations:
        for bad_schema in schemas:
            assert_raises(expected_error, sanitize, bad_schema)


def test_typing_error_detection():
    """A string value is reported as TypeError rather than unyt's RuntimeError."""
    bad_value = "1m"

    # unyt itself reports this bad input with a RuntimeError
    assert_raises(RuntimeError, mock_quan, bad_value)

    # the sanitizer must intercept the problem first and raise TypeError
    assert_raises(
        TypeError, Dataset._sanitize_units_override, {"length_unit": bad_value}
    )


def test_dimensionality_error_detection():
    """A time quantity given as length_unit is rejected with ValueError."""
    time_given_as_length = {"length_unit": YTQuantity(1.0, "s")}
    assert_raises(
        ValueError, Dataset._sanitize_units_override, time_given_as_length
    )
Loading

0 comments on commit ff56ee1

Please sign in to comment.