From 044d59adb4fcb277c928f291f7feca277342329e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 20 Oct 2021 19:07:32 -0400 Subject: [PATCH 01/50] draft implementation of @expects --- pint_xarray/__init__.py | 1 + pint_xarray/checking.py | 118 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 pint_xarray/checking.py diff --git a/pint_xarray/__init__.py b/pint_xarray/__init__.py index b854e5bd..b5a702f6 100644 --- a/pint_xarray/__init__.py +++ b/pint_xarray/__init__.py @@ -8,6 +8,7 @@ from . import accessors, formatting, testing # noqa: F401 from .accessors import default_registry as unit_registry from .accessors import setup_registry +from .checking import expects # noqa: F401 try: __version__ = version("pint-xarray") diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py new file mode 100644 index 00000000..0f694201 --- /dev/null +++ b/pint_xarray/checking.py @@ -0,0 +1,118 @@ +import functools + +from pint import Quantity +from xarray import DataArray + +from .accessors import PintDataArrayAccessor + + +def expects(*args_units, return_units=None, **kwargs_units): + """ + Decorator which checks the inputs and outputs of the decorated function have certain units. + + Arguments + + Note that the coordinates of input DataArrays are not checked, only the data. + So if your decorated function uses coordinates and you wish to check their units, + you should pass the coordinates of interest as separate arguments. + + Parameters + ---------- + func: function + Decorated function, which accepts zero or more xarray.DataArrays or pint.Quantitys as inputs, + and may optionally return one or more xarray.DataArrays or pint.Quantitys. + args_units : Union[str, pint.Unit, None] + Unit to expect for each positional argument given to func. + + A value of None indicates not to check that argument for units (suitable for flags + and other non-data arguments). + return_units : Union[Union[str, pint.Unit, None, False], Sequence[Union[str, pint.Unit, None]], Optional + The expected units of the returned value(s), either as a single unit or as an iterable of units. + + A value of None indicates not to check that return value for units (suitable for flags and other + non-data arguments). Passing False means that no return value is expected from the function at all, + and an error will be raised if a return value is found. + kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional + Unit to expect for each keyword argument given to func. + + A value of None indicates not to check that argument for units (suitable for flags + and other non-data arguments). + + Returns + ------- + return_values + Return values of the wrapped function, either a single value or a tuple of values. + + Raises + ------ + TypeError + If an argument or return value has a specified unit, but is not an xarray.DataArray. + + + Examples + -------- + + Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). + + >>> @expects('deg C') + ... def above_freezing(temp): + ... return temp > 0 + ... + + + TODO: example where we check units of an optional weighted kwarg + """ + + # TODO: Check args_units, kwargs_units, and return_units types + # TODO: Check number of arguments line up + + def _expects_decorator(func): + + @functools.wraps(func) + def _unit_checking_wrapper(*args, **kwargs): + + converted_args = [] + for arg, arg_unit in zip(args, args_units): + converted_arg = _check_then_convert_to(arg, arg_unit) + converted_args.append(converted_arg) + + converted_kwargs = {} + for key, val in kwargs.items(): + kwarg_unit = kwargs_units[key] + converted_kwargs[key] = _check_then_convert_to(val, kwarg_unit) + + results = func(*converted_args, **converted_kwargs) + + if results is not None: + if return_units is False: + raise ValueError("Did not expect function to return anything") + elif return_units is not None: + # TODO check something was actually returned + # TODO check same number of things were returned as expected + # TODO handle single return value vs tuple of return values + + converted_results = [] + for return_unit, return_value in zip(return_units, results): + converted_result = _check_then_convert_to(return_value, return_unit) + converted_results.append(converted_result) + + return tuple(converted_results) + else: + return results + else: + if return_units: + raise ValueError("Expected function to return something") + + return _unit_checking_wrapper + + return _expects_decorator + + +def _check_then_convert_to(obj, units): + if isinstance(obj, Quantity): + return obj.to(units) + elif isinstance(obj, DataArray): + return obj.pint.to(units) + else: + raise TypeError("Can only expect units for arguments of type xarray.DataArray or pint.Quantity," + f"not {type(obj)}") From 0754f2234aba37feb0d22c5183610918862d9c92 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 20 Oct 2021 19:07:52 -0400 Subject: [PATCH 02/50] sketch of different tests needed --- pint_xarray/tests/test_checking.py | 97 ++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 pint_xarray/tests/test_checking.py diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py new file mode 100644 index 00000000..4b9c7954 --- /dev/null +++ b/pint_xarray/tests/test_checking.py @@ -0,0 +1,97 @@ +import pytest +import pint +import xarray as xr + +from pint import UnitRegistry + +from ..checking import expects + +ureg = UnitRegistry() + + +class TestExpects: + def test_single_arg(self): + + @expects('degC') + def above_freezing(temp : pint.Quantity): + return temp.magnitude > 0 + + f_q = pint.Quantity(20, units='degF') + assert above_freezing(f_q) == False + + c_q = pint.Quantity(-2, units='degC') + assert above_freezing(c_q) == False + + @expects('degC') + def above_freezing(temp : xr.DataArray): + return temp.pint.magnitude > 0 + + f_da = xr.DataArray(20).pint.quantify(units='degF') + assert above_freezing(f_da) == False + + c_da = xr.DataArray(-2).pint.quantify(units='degC') + assert above_freezing(c_da) == False + + def test_single_kwarg(self): + + @expects('meters', c='meters / second') + def freq(wavelength, c=None): + if c is None: + c = ureg.speed_of_light + + return c / wavelength + + def test_single_return_value(self): + + @expects('Hz') + def period(freq): + return 1 / freq + + f = pint.Quantity(10, units='Hz') + + # test conversion + T = period(f) + assert f.units == 'seconds' + + # test wrong dimensions for conversion + ... + + @pytest.mark.xfail + def test_multiple_return_values(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_mixed_args_kwargs_return_values(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_invalid_input_types(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_invalid_return_types(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_unquantified_arrays(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_wrong_number_of_args(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_nonexistent_kwarg(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_unexpected_return_value(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_expected_return_value(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_wrong_number_of_return_values(self): + raise NotImplementedError From e879ef9c592900741fc4345bd88e3fbf45879600 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 22 Oct 2021 11:51:25 -0400 Subject: [PATCH 03/50] idea for test --- pint_xarray/tests/test_checking.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 4b9c7954..6fbaf3d9 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -43,18 +43,16 @@ def freq(wavelength, c=None): def test_single_return_value(self): - @expects('Hz') - def period(freq): - return 1 / freq + @expects('kg', 'm / s^2', return_units='newtons') + def second_law(m, a): + return m * a - f = pint.Quantity(10, units='Hz') + m_q = pint.Quantity(0.1, units='tons') + a_q = pint.Quantity(10, units='feet / second^2') + assert second_law(m_q, a_q).pint.units == pint.Unit('newtons') - # test conversion - T = period(f) - assert f.units == 'seconds' - - # test wrong dimensions for conversion - ... + m_da + a_da @pytest.mark.xfail def test_multiple_return_values(self): From aad793654eec1c8a905c4ec478d701060325d5c4 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 15:33:25 -0400 Subject: [PATCH 04/50] upgrade check then convert function to optionally take magnitude --- pint_xarray/checking.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 0f694201..b943d53b 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -108,11 +108,19 @@ def _unit_checking_wrapper(*args, **kwargs): return _expects_decorator -def _check_then_convert_to(obj, units): +def _check_then_convert_to(obj, units, magnitude): if isinstance(obj, Quantity): - return obj.to(units) + converted = obj.to(units) + if magnitude: + return converted.magnitude + else: + return converted elif isinstance(obj, DataArray): - return obj.pint.to(units) + converted = obj.pint.to(units) + if magnitude: + return converted.pint.magnitude + else: + return converted else: raise TypeError("Can only expect units for arguments of type xarray.DataArray or pint.Quantity," f"not {type(obj)}") From e354f4e013cd3ab3c9a1e7c29961d64a08d671f4 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 29 Nov 2021 15:34:15 -0500 Subject: [PATCH 05/50] removed magnitude option --- pint_xarray/checking.py | 48 ++++++++++++++++-------------- pint_xarray/tests/test_checking.py | 42 +++++++++++--------------- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index b943d53b..bb0f3df3 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -3,7 +3,7 @@ from pint import Quantity from xarray import DataArray -from .accessors import PintDataArrayAccessor +from .accessors import PintDataArrayAccessor # noqa def expects(*args_units, return_units=None, **kwargs_units): @@ -19,45 +19,52 @@ def expects(*args_units, return_units=None, **kwargs_units): Parameters ---------- func: function - Decorated function, which accepts zero or more xarray.DataArrays or pint.Quantitys as inputs, - and may optionally return one or more xarray.DataArrays or pint.Quantitys. + Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, + and may optionally return one or more xarray.DataArrays or numpy-like arrays. args_units : Union[str, pint.Unit, None] - Unit to expect for each positional argument given to func. + Units to expect for each positional argument given to func. + + The decorator will check that arguments passed to the decorated function possess these specific units, or will + attempt to convert the argument to these units. A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). return_units : Union[Union[str, pint.Unit, None, False], Sequence[Union[str, pint.Unit, None]], Optional The expected units of the returned value(s), either as a single unit or as an iterable of units. + The decorator will check that results returned from the decorated function possess these specific units, or will + attempt to convert the results to these units. + A value of None indicates not to check that return value for units (suitable for flags and other non-data arguments). Passing False means that no return value is expected from the function at all, and an error will be raised if a return value is found. kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional Unit to expect for each keyword argument given to func. + The decorator will check that arguments passed to the decorated function possess these specific units, or will + attempt to convert the argument to these units. + A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). Returns ------- - return_values + return_values : Any Return values of the wrapped function, either a single value or a tuple of values. Raises ------ TypeError - If an argument or return value has a specified unit, but is not an xarray.DataArray. - + If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. Examples -------- Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). - >>> @expects('deg C') + >>> @expects("deg C") ... def above_freezing(temp): ... return temp > 0 - ... TODO: example where we check units of an optional weighted kwarg @@ -67,7 +74,6 @@ def expects(*args_units, return_units=None, **kwargs_units): # TODO: Check number of arguments line up def _expects_decorator(func): - @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): @@ -93,7 +99,9 @@ def _unit_checking_wrapper(*args, **kwargs): converted_results = [] for return_unit, return_value in zip(return_units, results): - converted_result = _check_then_convert_to(return_value, return_unit) + converted_result = _check_then_convert_to( + return_value, return_unit + ) converted_results.append(converted_result) return tuple(converted_results) @@ -108,19 +116,15 @@ def _unit_checking_wrapper(*args, **kwargs): return _expects_decorator -def _check_then_convert_to(obj, units, magnitude): +def _check_then_convert_to(obj, units): if isinstance(obj, Quantity): converted = obj.to(units) - if magnitude: - return converted.magnitude - else: - return converted + return converted.magnitude elif isinstance(obj, DataArray): converted = obj.pint.to(units) - if magnitude: - return converted.pint.magnitude - else: - return converted + return converted.pint.magnitude else: - raise TypeError("Can only expect units for arguments of type xarray.DataArray or pint.Quantity," - f"not {type(obj)}") + raise TypeError( + "Can only expect units for arguments of type xarray.DataArray or pint.Quantity," + f"not {type(obj)}" + ) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 6fbaf3d9..86e68f5f 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -1,7 +1,6 @@ -import pytest import pint +import pytest import xarray as xr - from pint import UnitRegistry from ..checking import expects @@ -11,30 +10,24 @@ class TestExpects: def test_single_arg(self): + @expects("degC") + def above_freezing(temp): + return temp > 0 - @expects('degC') - def above_freezing(temp : pint.Quantity): - return temp.magnitude > 0 - - f_q = pint.Quantity(20, units='degF') - assert above_freezing(f_q) == False + f_q = pint.Quantity(20, units="degF") + assert not above_freezing(f_q) - c_q = pint.Quantity(-2, units='degC') - assert above_freezing(c_q) == False + c_q = pint.Quantity(-2, units="degC") + assert not above_freezing(c_q) - @expects('degC') - def above_freezing(temp : xr.DataArray): - return temp.pint.magnitude > 0 + f_da = xr.DataArray(20).pint.quantify(units="degF") + assert not above_freezing(f_da) - f_da = xr.DataArray(20).pint.quantify(units='degF') - assert above_freezing(f_da) == False - - c_da = xr.DataArray(-2).pint.quantify(units='degC') - assert above_freezing(c_da) == False + c_da = xr.DataArray(-2).pint.quantify(units="degC") + assert not above_freezing(c_da) def test_single_kwarg(self): - - @expects('meters', c='meters / second') + @expects("meters", c="meters / second") def freq(wavelength, c=None): if c is None: c = ureg.speed_of_light @@ -42,14 +35,13 @@ def freq(wavelength, c=None): return c / wavelength def test_single_return_value(self): - - @expects('kg', 'm / s^2', return_units='newtons') + @expects("kg", "m / s^2", return_units="newtons") def second_law(m, a): return m * a - m_q = pint.Quantity(0.1, units='tons') - a_q = pint.Quantity(10, units='feet / second^2') - assert second_law(m_q, a_q).pint.units == pint.Unit('newtons') + m_q = pint.Quantity(0.1, units="tons") + a_q = pint.Quantity(10, units="feet / second^2") + assert second_law(m_q, a_q).pint.units == pint.Unit("newtons") m_da a_da From 7727d8e564b2c7f444b8c3aeee15281faeefc6db Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 29 Nov 2021 16:30:09 -0500 Subject: [PATCH 06/50] works for single return value --- pint_xarray/checking.py | 74 +++++++++++++++++++----------- pint_xarray/tests/test_checking.py | 9 ++-- 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index bb0f3df3..f18c5264 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -24,25 +24,25 @@ def expects(*args_units, return_units=None, **kwargs_units): args_units : Union[str, pint.Unit, None] Units to expect for each positional argument given to func. - The decorator will check that arguments passed to the decorated function possess these specific units, or will - attempt to convert the argument to these units. + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). return_units : Union[Union[str, pint.Unit, None, False], Sequence[Union[str, pint.Unit, None]], Optional - The expected units of the returned value(s), either as a single unit or as an iterable of units. + The expected units of the returned value(s), either as a single unit or as an iterable of units. The decorator + will attach these units to the output. - The decorator will check that results returned from the decorated function possess these specific units, or will - attempt to convert the results to these units. - - A value of None indicates not to check that return value for units (suitable for flags and other + A value of None indicates not to attach any units to that return value (suitable for flags and other non-data arguments). Passing False means that no return value is expected from the function at all, and an error will be raised if a return value is found. kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional Unit to expect for each keyword argument given to func. - The decorator will check that arguments passed to the decorated function possess these specific units, or will - attempt to convert the argument to these units. + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). @@ -50,7 +50,8 @@ def expects(*args_units, return_units=None, **kwargs_units): Returns ------- return_values : Any - Return values of the wrapped function, either a single value or a tuple of values. + Return values of the wrapped function, either a single value or a tuple of values. These will have units + according to return_units. Raises ------ @@ -79,52 +80,73 @@ def _unit_checking_wrapper(*args, **kwargs): converted_args = [] for arg, arg_unit in zip(args, args_units): - converted_arg = _check_then_convert_to(arg, arg_unit) + converted_arg = _check_or_convert_to_then_strip(arg, arg_unit) converted_args.append(converted_arg) converted_kwargs = {} for key, val in kwargs.items(): kwarg_unit = kwargs_units[key] - converted_kwargs[key] = _check_then_convert_to(val, kwarg_unit) + converted_kwargs[key] = _check_or_convert_to_then_strip(val, kwarg_unit) results = func(*converted_args, **converted_kwargs) if results is not None: if return_units is False: - raise ValueError("Did not expect function to return anything") + raise ValueError( + f"Did not expect function to return anything, but function returned {results}" + ) elif return_units is not None: # TODO check something was actually returned - # TODO check same number of things were returned as expected - # TODO handle single return value vs tuple of return values - - converted_results = [] - for return_unit, return_value in zip(return_units, results): - converted_result = _check_then_convert_to( - return_value, return_unit - ) - converted_results.append(converted_result) - return tuple(converted_results) + # TODO handle single return value vs tuple of return values + if type(results) == tuple: + + # TODO check same number of things were returned as expected + + converted_results = [] + for return_unit, return_value in zip(return_units, results): + converted_result = _attach_units(return_value, return_unit) + converted_results.append(converted_result) + return tuple(converted_results) + else: + converted_result = _attach_units(results, return_units) + return converted_result else: + # ignore types and units of return values return results else: if return_units: - raise ValueError("Expected function to return something") + raise ValueError( + "Expected function to return something, but function returned None" + ) return _unit_checking_wrapper return _expects_decorator -def _check_then_convert_to(obj, units): +def _check_or_convert_to_then_strip(obj, units): + """ + Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units, + then strips the units from it. + """ + if isinstance(obj, Quantity): converted = obj.to(units) return converted.magnitude elif isinstance(obj, DataArray): converted = obj.pint.to(units) - return converted.pint.magnitude + return converted.pint.dequantify() else: raise TypeError( "Can only expect units for arguments of type xarray.DataArray or pint.Quantity," f"not {type(obj)}" ) + + +def _attach_units(obj, units): + """Attaches units, but can also create pint.Quantity objects from numpy scalars""" + if isinstance(obj, DataArray): + return obj.pint.quantify(units) + else: + return Quantity(obj, units=units) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 86e68f5f..f4ce8387 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -41,10 +41,13 @@ def second_law(m, a): m_q = pint.Quantity(0.1, units="tons") a_q = pint.Quantity(10, units="feet / second^2") - assert second_law(m_q, a_q).pint.units == pint.Unit("newtons") + result_q = second_law(m_q, a_q) + assert result_q.units == pint.Unit("newtons") - m_da - a_da + m_da = xr.DataArray(0.1).pint.quantify(units="tons") + a_da = xr.DataArray(10).pint.quantify(units="feet / second^2") + result_da = second_law(m_da, a_da) + assert result_da.pint.units == pint.Unit("newtons") @pytest.mark.xfail def test_multiple_return_values(self): From 13797790fa9cbd30894ee969a8c3f014b84a118e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 13:31:26 -0500 Subject: [PATCH 07/50] works for single kwarg --- pint_xarray/checking.py | 6 ++++-- pint_xarray/tests/test_checking.py | 24 ++++++++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index f18c5264..4fb52190 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -8,9 +8,11 @@ def expects(*args_units, return_units=None, **kwargs_units): """ - Decorator which checks the inputs and outputs of the decorated function have certain units. + Decorator which ensures the inputs and outputs of the decorated function have certain units. - Arguments + Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and + then stripped of their units before being passed into the undecorated function. Therefore the undecorated function + should expect unquantified DataArrays or numpy-like arrays, but with the values expressed in specific units. Note that the coordinates of input DataArrays are not checked, only the data. So if your decorated function uses coordinates and you wish to check their units, diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index f4ce8387..e685217c 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -27,13 +27,27 @@ def above_freezing(temp): assert not above_freezing(c_da) def test_single_kwarg(self): - @expects("meters", c="meters / second") + @expects("meters", c="meters / second", return_units="Hz") def freq(wavelength, c=None): if c is None: - c = ureg.speed_of_light + c = (1 * ureg.speed_of_light).to("meters / second").magnitude return c / wavelength + w_q = pint.Quantity(0.02, units="inches") + c_q = pint.Quantity(1e6, units="feet / second") + f_q = freq(w_q, c=c_q) + assert f_q.units == pint.Unit("hertz") + f_q = freq(w_q) + assert f_q.units == pint.Unit("hertz") + + w_da = xr.DataArray(0.02).pint.quantify(units="inches") + c_da = xr.DataArray(1e6).pint.quantify(units="feet / second") + f_da = freq(w_da, c=c_da) + assert f_da.pint.units == pint.Unit("hertz") + f_da = freq(w_da) + assert f_da.pint.units == pint.Unit("hertz") + def test_single_return_value(self): @expects("kg", "m / s^2", return_units="newtons") def second_law(m, a): @@ -41,13 +55,15 @@ def second_law(m, a): m_q = pint.Quantity(0.1, units="tons") a_q = pint.Quantity(10, units="feet / second^2") + expected_q = (m_q * a_q).to("newtons") result_q = second_law(m_q, a_q) - assert result_q.units == pint.Unit("newtons") + assert result_q == expected_q m_da = xr.DataArray(0.1).pint.quantify(units="tons") a_da = xr.DataArray(10).pint.quantify(units="feet / second^2") + expected_da = (m_da * a_da).pint.to("newtons") result_da = second_law(m_da, a_da) - assert result_da.pint.units == pint.Unit("newtons") + assert result_da == expected_da @pytest.mark.xfail def test_multiple_return_values(self): From 77f5d02eab9dd6f7e90d280f24513ec04486dab8 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 13:44:22 -0500 Subject: [PATCH 08/50] works for multiple return values --- pint_xarray/tests/test_checking.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index e685217c..c2c656c2 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -65,9 +65,28 @@ def second_law(m, a): result_da = second_law(m_da, a_da) assert result_da == expected_da - @pytest.mark.xfail def test_multiple_return_values(self): - raise NotImplementedError + @expects("kg", "m / s", return_units=["J", "newton seconds"]) + def energy_and_momentum(m, v): + ke = 0.5 * m * v**2 + p = m * v + return ke, p + + m = pint.Quantity(0.1, units="tons") + v = pint.Quantity(10, units="feet / second") + expected_ke = (0.5 * m * v**2).to("J") + expected_p = (m * v).to("newton seconds") + result_ke, result_p = energy_and_momentum(m, v) + assert result_ke.units == expected_ke.units + assert result_p.units == expected_p.units + + m = xr.DataArray(0.1).pint.quantify(units="tons") + a = xr.DataArray(10).pint.quantify(units="feet / second") + expected_ke = (0.5 * m * v**2).pint.to("J") + expected_p = (m * v).pint.to("newton seconds") + result_ke, result_p = energy_and_momentum(m, v) + assert result_ke.pint.units == expected_ke.pint.units + assert result_p.pint.units == expected_p.pint.units @pytest.mark.xfail def test_mixed_args_kwargs_return_values(self): From 71f420067c2de3715b7305cc1c40f4d95e81c2cc Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 14:18:16 -0500 Subject: [PATCH 09/50] allow passing through arguments unchecked --- pint_xarray/checking.py | 26 +++++++++++++++----------- pint_xarray/tests/test_checking.py | 16 ++++++++++++---- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 4fb52190..f2725e97 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -8,7 +8,7 @@ def expects(*args_units, return_units=None, **kwargs_units): """ - Decorator which ensures the inputs and outputs of the decorated function have certain units. + Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and then stripped of their units before being passed into the undecorated function. Therefore the undecorated function @@ -133,17 +133,21 @@ def _check_or_convert_to_then_strip(obj, units): then strips the units from it. """ - if isinstance(obj, Quantity): - converted = obj.to(units) - return converted.magnitude - elif isinstance(obj, DataArray): - converted = obj.pint.to(units) - return converted.pint.dequantify() + if units is None: + # allow for passing through non-numerical arguments + return obj else: - raise TypeError( - "Can only expect units for arguments of type xarray.DataArray or pint.Quantity," - f"not {type(obj)}" - ) + if isinstance(obj, Quantity): + converted = obj.to(units) + return converted.magnitude + elif isinstance(obj, DataArray): + converted = obj.pint.to(units) + return converted.pint.dequantify() + else: + raise TypeError( + "Can only expect units for arguments of type xarray.DataArray or pint.Quantity," + f"not {type(obj)}" + ) def _attach_units(obj, units): diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index c2c656c2..f8ece969 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -68,26 +68,34 @@ def second_law(m, a): def test_multiple_return_values(self): @expects("kg", "m / s", return_units=["J", "newton seconds"]) def energy_and_momentum(m, v): - ke = 0.5 * m * v**2 + ke = 0.5 * m * v ** 2 p = m * v return ke, p m = pint.Quantity(0.1, units="tons") v = pint.Quantity(10, units="feet / second") - expected_ke = (0.5 * m * v**2).to("J") + expected_ke = (0.5 * m * v ** 2).to("J") expected_p = (m * v).to("newton seconds") result_ke, result_p = energy_and_momentum(m, v) assert result_ke.units == expected_ke.units assert result_p.units == expected_p.units m = xr.DataArray(0.1).pint.quantify(units="tons") - a = xr.DataArray(10).pint.quantify(units="feet / second") - expected_ke = (0.5 * m * v**2).pint.to("J") + v = xr.DataArray(10).pint.quantify(units="feet / second") + expected_ke = (0.5 * m * v ** 2).pint.to("J") expected_p = (m * v).pint.to("newton seconds") result_ke, result_p = energy_and_momentum(m, v) assert result_ke.pint.units == expected_ke.pint.units assert result_p.pint.units == expected_p.pint.units + def test_dont_check_arg_units(self): + @expects("seconds", None, return_units=None) + def finite_difference(a, type): + return ... + + t = pint.Quantity(0.1, units="seconds") + finite_difference(t, "centered") + @pytest.mark.xfail def test_mixed_args_kwargs_return_values(self): raise NotImplementedError From a710741277ac36c530295cbf6d1fa7eb72f89069 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 14:41:40 -0500 Subject: [PATCH 10/50] check types of units --- pint_xarray/checking.py | 21 ++++++++++++++++----- pint_xarray/tests/test_checking.py | 12 +++++++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index f2725e97..c756a9ea 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -1,6 +1,6 @@ import functools -from pint import Quantity +from pint import Quantity, Unit from xarray import DataArray from .accessors import PintDataArrayAccessor # noqa @@ -32,9 +32,9 @@ def expects(*args_units, return_units=None, **kwargs_units): A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). - return_units : Union[Union[str, pint.Unit, None, False], Sequence[Union[str, pint.Unit, None]], Optional - The expected units of the returned value(s), either as a single unit or as an iterable of units. The decorator - will attach these units to the output. + return_units : Union[Union[str, pint.Unit, None, False], List[Union[str, pint.Unit, None]], Optional + The expected units of the returned value(s), either as a single unit or as a list of units. The decorator + will attach these units to the variables returned from the function. A value of None indicates not to attach any units to that return value (suitable for flags and other non-data arguments). Passing False means that no return value is expected from the function at all, @@ -73,7 +73,18 @@ def expects(*args_units, return_units=None, **kwargs_units): TODO: example where we check units of an optional weighted kwarg """ - # TODO: Check args_units, kwargs_units, and return_units types + # Check types of args_units, kwargs_units, and return_units + all_units = list(args_units) + list(kwargs_units.values()) + if isinstance(return_units, list): + all_units = all_units + return_units + elif return_units: + all_units = all_units + [return_units] + for a in all_units: + if not isinstance(a, (Unit, str)) and a is not None: + raise TypeError( + f"{a} is not a valid type for a unit, it is of type {type(a)}" + ) + # TODO: Check number of arguments line up def _expects_decorator(func): diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index f8ece969..ba1e406a 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -100,9 +100,15 @@ def finite_difference(a, type): def test_mixed_args_kwargs_return_values(self): raise NotImplementedError - @pytest.mark.xfail - def test_invalid_input_types(self): - raise NotImplementedError + @pytest.mark.parametrize( + "arg_units, return_units", [(True, None), ("seconds", 6), ("seconds", [6])] + ) + def test_invalid_unit_types(self, arg_units, return_units): + with pytest.raises(TypeError): + + @expects(arg_units, return_units=return_units) + def freq(period): + ... @pytest.mark.xfail def test_invalid_return_types(self): From 497e97f864286d62e664bb7b3472044c5100f4ab Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 15:04:52 -0500 Subject: [PATCH 11/50] remove uneeded option to specify a lack of return value --- pint_xarray/checking.py | 25 ++++++++++--------------- pint_xarray/tests/test_checking.py | 4 ---- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index c756a9ea..37951d40 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -30,15 +30,14 @@ def expects(*args_units, return_units=None, **kwargs_units): (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude to the wrapped function. - A value of None indicates not to check that argument for units (suitable for flags - and other non-data arguments). - return_units : Union[Union[str, pint.Unit, None, False], List[Union[str, pint.Unit, None]], Optional + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). + return_units : Union[Union[str, pint.Unit, None], List[Union[str, pint.Unit, None]], Optional The expected units of the returned value(s), either as a single unit or as a list of units. The decorator will attach these units to the variables returned from the function. A value of None indicates not to attach any units to that return value (suitable for flags and other - non-data arguments). Passing False means that no return value is expected from the function at all, - and an error will be raised if a return value is found. + non-data results). kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional Unit to expect for each keyword argument given to func. @@ -46,8 +45,8 @@ def expects(*args_units, return_units=None, **kwargs_units): (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude to the wrapped function. - A value of None indicates not to check that argument for units (suitable for flags - and other non-data arguments). + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). Returns ------- @@ -104,11 +103,10 @@ def _unit_checking_wrapper(*args, **kwargs): results = func(*converted_args, **converted_kwargs) if results is not None: - if return_units is False: - raise ValueError( - f"Did not expect function to return anything, but function returned {results}" - ) - elif return_units is not None: + if return_units is None: + # ignore types and units of return values + return results + else: # TODO check something was actually returned # TODO handle single return value vs tuple of return values @@ -124,9 +122,6 @@ def _unit_checking_wrapper(*args, **kwargs): else: converted_result = _attach_units(results, return_units) return converted_result - else: - # ignore types and units of return values - return results else: if return_units: raise ValueError( diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index ba1e406a..da71cde0 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -126,10 +126,6 @@ def test_wrong_number_of_args(self): def test_nonexistent_kwarg(self): raise NotImplementedError - @pytest.mark.xfail - def test_unexpected_return_value(self): - raise NotImplementedError - @pytest.mark.xfail def test_expected_return_value(self): raise NotImplementedError From 00219bc28af27bb0db7ce74fa78613d6d1a8022e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 17:49:55 -0500 Subject: [PATCH 12/50] check number of inputs and return values --- pint_xarray/checking.py | 67 +++++++++++++++++++----------- pint_xarray/tests/test_checking.py | 40 +++++++++++------- 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 37951d40..0647d62d 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -51,13 +51,15 @@ def expects(*args_units, return_units=None, **kwargs_units): Returns ------- return_values : Any - Return values of the wrapped function, either a single value or a tuple of values. These will have units + Return values of the wrapped function, either a single value or a tuple of values. These will be given units according to return_units. Raises ------ TypeError If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. + Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not + match the number of units specified. Examples -------- @@ -72,6 +74,8 @@ def expects(*args_units, return_units=None, **kwargs_units): TODO: example where we check units of an optional weighted kwarg """ + # TODO generalise to allow for dictionaries of units for DataArray coordinates / Datasets + # Check types of args_units, kwargs_units, and return_units all_units = list(args_units) + list(kwargs_units.values()) if isinstance(return_units, list): @@ -84,12 +88,20 @@ def expects(*args_units, return_units=None, **kwargs_units): f"{a} is not a valid type for a unit, it is of type {type(a)}" ) - # TODO: Check number of arguments line up - def _expects_decorator(func): @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): + # without this we get an UnboundLocalError but I have no idea why + # see https://stackoverflow.com/questions/5630409/ + nonlocal return_units + + # check same number of arguments were passed as expected + if len(args) != len(args_units): + raise TypeError( + f"{len(args)} arguments were passed, but {len(args_units)} arguments were expected" + ) + converted_args = [] for arg, arg_unit in zip(args, args_units): converted_arg = _check_or_convert_to_then_strip(arg, arg_unit) @@ -97,36 +109,43 @@ def _unit_checking_wrapper(*args, **kwargs): converted_kwargs = {} for key, val in kwargs.items(): - kwarg_unit = kwargs_units[key] + kwarg_unit = kwargs_units.get(key, None) converted_kwargs[key] = _check_or_convert_to_then_strip(val, kwarg_unit) results = func(*converted_args, **converted_kwargs) - if results is not None: + if results is None: + if return_units: + raise TypeError( + "Expected function to return something, but function returned None" + ) + else: if return_units is None: # ignore types and units of return values return results else: - # TODO check something was actually returned - - # TODO handle single return value vs tuple of return values - if type(results) == tuple: - - # TODO check same number of things were returned as expected - - converted_results = [] - for return_unit, return_value in zip(return_units, results): - converted_result = _attach_units(return_value, return_unit) - converted_results.append(converted_result) - return tuple(converted_results) + # handle case of function returning only one result by promoting to 1-element tuple + if not isinstance(results, tuple): + results = (results,) + if not isinstance(return_units, (tuple, list)): + return_units = (return_units,) + + # check same number of things were returned as expected + if len(results) != len(return_units): + raise TypeError( + f"{len(results)} return values were received, but {len(return_units)} " + "return values were expected" + ) + + converted_results = [] + for return_unit, return_value in zip(return_units, results): + converted_result = _attach_units(return_value, return_unit) + converted_results.append(converted_result) + + if len(converted_results) == 1: + return tuple(converted_results)[0] else: - converted_result = _attach_units(results, return_units) - return converted_result - else: - if return_units: - raise ValueError( - "Expected function to return something, but function returned None" - ) + return tuple(converted_results) return _unit_checking_wrapper diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index da71cde0..639b861a 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -96,10 +96,6 @@ def finite_difference(a, type): t = pint.Quantity(0.1, units="seconds") finite_difference(t, "centered") - @pytest.mark.xfail - def test_mixed_args_kwargs_return_values(self): - raise NotImplementedError - @pytest.mark.parametrize( "arg_units, return_units", [(True, None), ("seconds", 6), ("seconds", [6])] ) @@ -118,18 +114,34 @@ def test_invalid_return_types(self): def test_unquantified_arrays(self): raise NotImplementedError - @pytest.mark.xfail def test_wrong_number_of_args(self): - raise NotImplementedError + @expects("kg", return_units="newtons") + def second_law(m, a): + return m * a - @pytest.mark.xfail - def test_nonexistent_kwarg(self): - raise NotImplementedError + m_q = pint.Quantity(0.1, units="tons") + a_q = pint.Quantity(10, units="feet / second^2") - @pytest.mark.xfail - def test_expected_return_value(self): - raise NotImplementedError + with pytest.raises(TypeError, match="1 arguments were expected"): + second_law(m_q, a_q) - @pytest.mark.xfail def test_wrong_number_of_return_values(self): - raise NotImplementedError + @expects("kg", "m / s^2", return_units=["newtons", "joules"]) + def second_law(m, a): + return m * a + + m_q = pint.Quantity(0.1, units="tons") + a_q = pint.Quantity(10, units="feet / second^2") + + with pytest.raises(TypeError, match="2 return values were expected"): + second_law(m_q, a_q) + + def test_expected_return_value(self): + @expects("seconds", return_units="Hz") + def freq(period): + return None + + p = pint.Quantity(2, units="seconds") + + with pytest.raises(TypeError, match="function returned None"): + freq(p) From 9e92f21618211da7d059b8faa440f0a96b59a729 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 18:10:39 -0500 Subject: [PATCH 13/50] removed nonlocal keyword --- pint_xarray/checking.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 0647d62d..3c1e68d5 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -92,10 +92,6 @@ def _expects_decorator(func): @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): - # without this we get an UnboundLocalError but I have no idea why - # see https://stackoverflow.com/questions/5630409/ - nonlocal return_units - # check same number of arguments were passed as expected if len(args) != len(args_units): raise TypeError( @@ -127,18 +123,21 @@ def _unit_checking_wrapper(*args, **kwargs): # handle case of function returning only one result by promoting to 1-element tuple if not isinstance(results, tuple): results = (results,) - if not isinstance(return_units, (tuple, list)): - return_units = (return_units,) + if isinstance(return_units, (tuple, list)): + # avoid mutating return_units because that would change the variables' scope + return_units_list = return_units + else: + return_units_list = [return_units] # check same number of things were returned as expected - if len(results) != len(return_units): + if len(results) != len(return_units_list): raise TypeError( - f"{len(results)} return values were received, but {len(return_units)} " + f"{len(results)} return values were received, but {len(return_units_list)} " "return values were expected" ) converted_results = [] - for return_unit, return_value in zip(return_units, results): + for return_unit, return_value in zip(return_units_list, results): converted_result = _attach_units(return_value, return_unit) converted_results.append(converted_result) From 86f7e58d485bb35a79e597bcd59985164204921b Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 18:24:04 -0500 Subject: [PATCH 14/50] generalised to handle specifying dicts of units --- pint_xarray/checking.py | 20 ++++++++++++-------- pint_xarray/tests/test_checking.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 3c1e68d5..02e17607 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -1,7 +1,7 @@ import functools from pint import Quantity, Unit -from xarray import DataArray +from xarray import DataArray, Dataset from .accessors import PintDataArrayAccessor # noqa @@ -74,8 +74,6 @@ def expects(*args_units, return_units=None, **kwargs_units): TODO: example where we check units of an optional weighted kwarg """ - # TODO generalise to allow for dictionaries of units for DataArray coordinates / Datasets - # Check types of args_units, kwargs_units, and return_units all_units = list(args_units) + list(kwargs_units.values()) if isinstance(return_units, list): @@ -83,10 +81,11 @@ def expects(*args_units, return_units=None, **kwargs_units): elif return_units: all_units = all_units + [return_units] for a in all_units: - if not isinstance(a, (Unit, str)) and a is not None: - raise TypeError( - f"{a} is not a valid type for a unit, it is of type {type(a)}" - ) + if isinstance(a, dict): + for u in a.values(): + _check_valid_unit_type(u) + else: + _check_valid_unit_type(a) def _expects_decorator(func): @functools.wraps(func) @@ -151,6 +150,11 @@ def _unit_checking_wrapper(*args, **kwargs): return _expects_decorator +def _check_valid_unit_type(a): + if not isinstance(a, (Unit, str)) and a is not None: + raise TypeError(f"{a} is not a valid type for a unit, it is of type {type(a)}") + + def _check_or_convert_to_then_strip(obj, units): """ Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units, @@ -164,7 +168,7 @@ def _check_or_convert_to_then_strip(obj, units): if isinstance(obj, Quantity): converted = obj.to(units) return converted.magnitude - elif isinstance(obj, DataArray): + elif isinstance(obj, (DataArray, Dataset)): converted = obj.pint.to(units) return converted.pint.dequantify() else: diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 639b861a..4de43938 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -145,3 +145,16 @@ def freq(period): with pytest.raises(TypeError, match="function returned None"): freq(p) + + def test_unit_dict(self): + @expects({"m": "kg", "a": "m / s^2"}, return_units="newtons") + def second_law(ds): + return ds["m"] * ds["a"] + + m_da = xr.DataArray(0.1).pint.quantify(units="tons") + a_da = xr.DataArray(10).pint.quantify(units="feet / second^2") + ds = xr.Dataset({"m": m_da, "a": a_da}) + + expected_da = (m_da * a_da).pint.to("newtons") + result_da = second_law(ds) + assert result_da == expected_da From 2141c6c829aba0830b6819cdec7487bd878b25ad Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 30 Nov 2021 18:24:48 -0500 Subject: [PATCH 15/50] type hint for func Co-authored-by: keewis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 3c1e68d5..813d603e 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -20,7 +20,7 @@ def expects(*args_units, return_units=None, **kwargs_units): Parameters ---------- - func: function + func : callable Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, and may optionally return one or more xarray.DataArrays or numpy-like arrays. args_units : Union[str, pint.Unit, None] From a94a6ae02ae3c727aea5d9d90ef059ecdf761fcd Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 30 Nov 2021 18:25:17 -0500 Subject: [PATCH 16/50] type hint for args_units Co-authored-by: keewis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 813d603e..29bc0d5e 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -23,7 +23,7 @@ def expects(*args_units, return_units=None, **kwargs_units): func : callable Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, and may optionally return one or more xarray.DataArrays or numpy-like arrays. - args_units : Union[str, pint.Unit, None] + *args_units : None or unit-like or mapping of None or unit-like, optional Units to expect for each positional argument given to func. The decorator will first check that arguments passed to the decorated function possess these specific units From 7103483b020e759f5d207bd95f79e41609d9389a Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 18:28:38 -0500 Subject: [PATCH 17/50] numpy-style type hints for all arguments --- pint_xarray/checking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 9528450a..f98c7113 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -23,7 +23,7 @@ def expects(*args_units, return_units=None, **kwargs_units): func : callable Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, and may optionally return one or more xarray.DataArrays or numpy-like arrays. - *args_units : None or unit-like or mapping of None or unit-like, optional + *args_units : unit-like or mapping of unit-like or None, optional Units to expect for each positional argument given to func. The decorator will first check that arguments passed to the decorated function possess these specific units @@ -32,13 +32,13 @@ def expects(*args_units, return_units=None, **kwargs_units): A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). - return_units : Union[Union[str, pint.Unit, None], List[Union[str, pint.Unit, None]], Optional + return_units : unit-like or mapping of unit-like or list of unit-like or mapping of unit-like or None, optional The expected units of the returned value(s), either as a single unit or as a list of units. The decorator will attach these units to the variables returned from the function. A value of None indicates not to attach any units to that return value (suitable for flags and other non-data results). - kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional + kwargs_units : mapping of unit-like or None, Optional Unit to expect for each keyword argument given to func. The decorator will first check that arguments passed to the decorated function possess these specific units From 59ddf86b793f65d8cc2989493beabaa46f8a0737 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 18:34:17 -0500 Subject: [PATCH 18/50] whats new --- docs/whats-new.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/whats-new.rst b/docs/whats-new.rst index 256baffe..bcca5b52 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -3,6 +3,12 @@ What's new ========== +0.2.2 (unreleased) +------------------ + +- Added the :py:func:`pint_xarray.expects` decorator (:pull:`143`). + By `Tom Nicholas `_ + 0.2.1 (26 Jul 2021) ------------------- - allow special "no unit" values in :py:meth:`Dataset.pint.quantify` and From a5a2493df0c911c9a1211793059f68be34de3c97 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 30 Nov 2021 18:36:13 -0500 Subject: [PATCH 19/50] add to API docs --- docs/api.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index bb7cb2f9..d172555d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -65,6 +65,14 @@ DataArray xarray.DataArray.pint.bfill xarray.DataArray.pint.interpolate_na +Checking +------- + +.. autosummary:: + :toctree: generated/ + + pint_xarray.expects + Testing ------- From e5e84fb8daf292a0eccc6f0bcbb9364d700f48ef Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 14 Dec 2021 14:20:59 -0500 Subject: [PATCH 20/50] use always_iterable --- pint_xarray/checking.py | 49 ++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index f98c7113..057933f1 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -120,25 +120,19 @@ def _unit_checking_wrapper(*args, **kwargs): return results else: # handle case of function returning only one result by promoting to 1-element tuple - if not isinstance(results, tuple): - results = (results,) - if isinstance(return_units, (tuple, list)): - # avoid mutating return_units because that would change the variables' scope - return_units_list = return_units - else: - return_units_list = [return_units] + return_units_iterable = tuple(always_iterable(return_units)) + results_iterable = tuple(always_iterable(results)) # check same number of things were returned as expected - if len(results) != len(return_units_list): + if len(results_iterable) != len(return_units_iterable): raise TypeError( - f"{len(results)} return values were received, but {len(return_units_list)} " + f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} " "return values were expected" ) - converted_results = [] - for return_unit, return_value in zip(return_units_list, results): - converted_result = _attach_units(return_value, return_unit) - converted_results.append(converted_result) + converted_results = _attach_multiple_units( + results_iterable, return_units_iterable + ) if len(converted_results) == 1: return tuple(converted_results)[0] @@ -184,3 +178,32 @@ def _attach_units(obj, units): return obj.pint.quantify(units) else: return Quantity(obj, units=units) + + +def _attach_multiple_units(objects, units): + """Attaches list of units to list of objects elementwise""" + converted_objects = [_attach_units(obj, unit) for obj, unit in zip(objects, units)] + return converted_objects + + +def always_iterable(obj, base_type=(str, bytes)): + """ + If *obj* is iterable, return an iterator over its items, + If *obj* is not iterable, return a one-item iterable containing *obj*, + If *obj* is ``None``, return an empty iterable. + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + Copied from more_itertools. + """ + + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) From 3f594145e535aa5089056ea8b4f8abd3fbb5205e Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:22:48 -0500 Subject: [PATCH 21/50] hashable Co-authored-by: keewis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 057933f1..98f4575c 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -23,7 +23,7 @@ def expects(*args_units, return_units=None, **kwargs_units): func : callable Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, and may optionally return one or more xarray.DataArrays or numpy-like arrays. - *args_units : unit-like or mapping of unit-like or None, optional + *args_units : unit-like or mapping of hashable to unit-like, optional Units to expect for each positional argument given to func. The decorator will first check that arguments passed to the decorated function possess these specific units From b281674b3842a56a18d1d27c8ddbde00150748b8 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:22:57 -0500 Subject: [PATCH 22/50] hashable Co-authored-by: keewis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 98f4575c..37e69d93 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -38,7 +38,7 @@ def expects(*args_units, return_units=None, **kwargs_units): A value of None indicates not to attach any units to that return value (suitable for flags and other non-data results). - kwargs_units : mapping of unit-like or None, Optional + kwargs_units : mapping of hashable to unit-like, optional Unit to expect for each keyword argument given to func. The decorator will first check that arguments passed to the decorated function possess these specific units From c6691057d7bb601466fb87f9e3bc0c620a2fddf4 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:23:08 -0500 Subject: [PATCH 23/50] hashable Co-authored-by: keewis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 37e69d93..39a10376 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -32,7 +32,7 @@ def expects(*args_units, return_units=None, **kwargs_units): A value of None indicates not to check that argument for units (suitable for flags and other non-data arguments). - return_units : unit-like or mapping of unit-like or list of unit-like or mapping of unit-like or None, optional + return_units : unit-like or mapping of hashable to unit-like or list of unit-like or list of mapping of hashable to unit-like, optional The expected units of the returned value(s), either as a single unit or as a list of units. The decorator will attach these units to the variables returned from the function. From 9ac88877a75103d2ef81ee4a3639a94b00fbcb05 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:39:53 -0500 Subject: [PATCH 24/50] dict comprehension Co-authored-by: keewis --- pint_xarray/checking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 39a10376..9d6b4e02 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -102,10 +102,10 @@ def _unit_checking_wrapper(*args, **kwargs): converted_arg = _check_or_convert_to_then_strip(arg, arg_unit) converted_args.append(converted_arg) - converted_kwargs = {} - for key, val in kwargs.items(): - kwarg_unit = kwargs_units.get(key, None) - converted_kwargs[key] = _check_or_convert_to_then_strip(val, kwarg_unit) + converted_kwargs = { + key: _check_or_convert_to_then_strip(val, kwargs_units.get(key, None) + for key, val in kwargs.items() + } results = func(*converted_args, **converted_kwargs) From 0a6447d4a3571ec79a555b44d7edaafb3aa0d1e9 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:40:21 -0500 Subject: [PATCH 25/50] list comprehension Co-authored-by: keewis --- pint_xarray/checking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 9d6b4e02..3aecb9c8 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -97,10 +97,10 @@ def _unit_checking_wrapper(*args, **kwargs): f"{len(args)} arguments were passed, but {len(args_units)} arguments were expected" ) - converted_args = [] - for arg, arg_unit in zip(args, args_units): - converted_arg = _check_or_convert_to_then_strip(arg, arg_unit) - converted_args.append(converted_arg) + converted_args = [ + _check_or_convert_to_then_strip(arg, arg_unit) + for arg, arg_unit in zip(args, args_units) + ] converted_kwargs = { key: _check_or_convert_to_then_strip(val, kwargs_units.get(key, None) From c29f935d0128e2afec3bddd0521c8dc3295c7afa Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:40:50 -0500 Subject: [PATCH 26/50] unindent if/else Co-authored-by: keewis --- pint_xarray/checking.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 3aecb9c8..7f091878 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -158,18 +158,17 @@ def _check_or_convert_to_then_strip(obj, units): if units is None: # allow for passing through non-numerical arguments return obj + elif isinstance(obj, Quantity): + converted = obj.to(units) + return converted.magnitude + elif isinstance(obj, (DataArray, Dataset)): + converted = obj.pint.to(units) + return converted.pint.dequantify() else: - if isinstance(obj, Quantity): - converted = obj.to(units) - return converted.magnitude - elif isinstance(obj, (DataArray, Dataset)): - converted = obj.pint.to(units) - return converted.pint.dequantify() - else: - raise TypeError( - "Can only expect units for arguments of type xarray.DataArray or pint.Quantity," - f"not {type(obj)}" - ) + raise TypeError( + "Can only expect units for arguments of type xarray.DataArray," + f" xarray.Dataset, or pint.Quantity, not {type(obj)}" + ) def _attach_units(obj, units): From 81913a6bd0d1d60f84729cc1c013154654f2142a Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 14 Dec 2021 14:42:14 -0500 Subject: [PATCH 27/50] missing parenthesis --- pint_xarray/checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 7f091878..b4fefe99 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -103,7 +103,7 @@ def _unit_checking_wrapper(*args, **kwargs): ] converted_kwargs = { - key: _check_or_convert_to_then_strip(val, kwargs_units.get(key, None) + key: _check_or_convert_to_then_strip(val, kwargs_units.get(key, None)) for key, val in kwargs.items() } From 4de6f4d66e534c8e769fdf874c778afff34b5196 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 14 Dec 2021 15:24:54 -0500 Subject: [PATCH 28/50] simplify if/else logic for checking there were actually results --- pint_xarray/checking.py | 178 ++++++++++++++++++++-------------------- 1 file changed, 88 insertions(+), 90 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index b4fefe99..53318ba2 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -8,70 +8,70 @@ def expects(*args_units, return_units=None, **kwargs_units): """ - Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. - - Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and - then stripped of their units before being passed into the undecorated function. Therefore the undecorated function - should expect unquantified DataArrays or numpy-like arrays, but with the values expressed in specific units. - - Note that the coordinates of input DataArrays are not checked, only the data. - So if your decorated function uses coordinates and you wish to check their units, - you should pass the coordinates of interest as separate arguments. - - Parameters - ---------- - func : callable - Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, - and may optionally return one or more xarray.DataArrays or numpy-like arrays. - *args_units : unit-like or mapping of hashable to unit-like, optional - Units to expect for each positional argument given to func. - - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. - - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). - return_units : unit-like or mapping of hashable to unit-like or list of unit-like or list of mapping of hashable to unit-like, optional - The expected units of the returned value(s), either as a single unit or as a list of units. The decorator - will attach these units to the variables returned from the function. - - A value of None indicates not to attach any units to that return value (suitable for flags and other - non-data results). - kwargs_units : mapping of hashable to unit-like, optional - Unit to expect for each keyword argument given to func. - - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. - - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). - - Returns - ------- - return_values : Any - Return values of the wrapped function, either a single value or a tuple of values. These will be given units - according to return_units. - - Raises - ------ - TypeError - If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. - Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not - match the number of units specified. - - Examples - -------- - - Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). - - >>> @expects("deg C") - ... def above_freezing(temp): - ... return temp > 0 - - - TODO: example where we check units of an optional weighted kwarg + Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. + + Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and + then stripped of their units before being passed into the undecorated function. Therefore the undecorated function + should expect unquantified DataArrays or numpy-like arrays, but with the values expressed in specific units. + + Note that the coordinates of input DataArrays are not checked, only the data. + So if your decorated function uses coordinates and you wish to check their units, + you should pass the coordinates of interest as separate arguments. + + Parameters + ---------- + func : callable + Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, + and may optionally return one or more xarray.DataArrays or numpy-like arrays. + *args_units : unit-like or mapping of hashable to unit-like, optional + Units to expect for each positional argument given to func. + + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. + + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). + return_units : unit-like or mapping of hashable to unit-like or list of unit-like or list of mapping of hashable to unit-like, optional + The expected units of the returned value(s), either as a single unit or as a list of units. The decorator + will attach these units to the variables returned from the function. + + A value of None indicates not to attach any units to that return value (suitable for flags and other + non-data results). + kwargs_units : mapping of hashable to unit-like, optional + Unit to expect for each keyword argument given to func. + + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. + + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). + + Returns + ------- + return_values : Any + Return values of the wrapped function, either a single value or a tuple of values. These will be given units + according to return_units. + + Raises + ------ + TypeError + If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. + Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not + match the number of units specified. + + Examples + -------- + + Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). + + >>> @expects("deg C") + ... def above_freezing(temp): + ... return temp > 0 + + + TODO: example where we check units of an optional weighted kwarg """ # Check types of args_units, kwargs_units, and return_units @@ -109,35 +109,33 @@ def _unit_checking_wrapper(*args, **kwargs): results = func(*converted_args, **converted_kwargs) - if results is None: - if return_units: - raise TypeError( - "Expected function to return something, but function returned None" - ) + if return_units is None: + # ignore types and units of return values + return results + elif results is None: + raise TypeError( + "Expected function to return something, but function returned None" + ) else: - if return_units is None: - # ignore types and units of return values - return results - else: - # handle case of function returning only one result by promoting to 1-element tuple - return_units_iterable = tuple(always_iterable(return_units)) - results_iterable = tuple(always_iterable(results)) - - # check same number of things were returned as expected - if len(results_iterable) != len(return_units_iterable): - raise TypeError( - f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} " - "return values were expected" - ) - - converted_results = _attach_multiple_units( - results_iterable, return_units_iterable + # handle case of function returning only one result by promoting to 1-element tuple + return_units_iterable = tuple(always_iterable(return_units)) + results_iterable = tuple(always_iterable(results)) + + # check same number of things were returned as expected + if len(results_iterable) != len(return_units_iterable): + raise TypeError( + f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} " + "return values were expected" ) - if len(converted_results) == 1: - return tuple(converted_results)[0] - else: - return tuple(converted_results) + converted_results = _attach_multiple_units( + results_iterable, return_units_iterable + ) + + if len(converted_results) == 1: + return tuple(converted_results)[0] + else: + return tuple(converted_results) return _unit_checking_wrapper From 83e422fafc30a6e6ea2d2da6a1135eae4a7c3004 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 14 Dec 2021 15:37:41 -0500 Subject: [PATCH 29/50] return results immediately if a tuple --- pint_xarray/checking.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 53318ba2..0510ee36 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -132,10 +132,13 @@ def _unit_checking_wrapper(*args, **kwargs): results_iterable, return_units_iterable ) - if len(converted_results) == 1: - return tuple(converted_results)[0] + if isinstance(results, tuple): + return converted_results else: - return tuple(converted_results) + if len(converted_results) == 1: + return converted_results[0] + else: + return converted_results return _unit_checking_wrapper From 37c3fbcded4351f5ad25ab9e1ff98081959611ce Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 14 Dec 2021 16:30:07 -0500 Subject: [PATCH 30/50] allow for returning Datasets from wrapped funciton --- pint_xarray/checking.py | 12 ++++++++---- pint_xarray/tests/test_checking.py | 17 ++++++++++++++++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 0510ee36..0d4a5d8f 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -12,7 +12,7 @@ def expects(*args_units, return_units=None, **kwargs_units): Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and then stripped of their units before being passed into the undecorated function. Therefore the undecorated function - should expect unquantified DataArrays or numpy-like arrays, but with the values expressed in specific units. + should expect unquantified DataArrays, Datasets, or numpy-like arrays, but with the values expressed in specific units. Note that the coordinates of input DataArrays are not checked, only the data. So if your decorated function uses coordinates and you wish to check their units, @@ -118,8 +118,12 @@ def _unit_checking_wrapper(*args, **kwargs): ) else: # handle case of function returning only one result by promoting to 1-element tuple - return_units_iterable = tuple(always_iterable(return_units)) - results_iterable = tuple(always_iterable(results)) + return_units_iterable = tuple( + always_iterable(return_units, base_type=(str, dict)) + ) + results_iterable = tuple( + always_iterable(results, base_type=(str, Dataset)) + ) # check same number of things were returned as expected if len(results_iterable) != len(return_units_iterable): @@ -174,7 +178,7 @@ def _check_or_convert_to_then_strip(obj, units): def _attach_units(obj, units): """Attaches units, but can also create pint.Quantity objects from numpy scalars""" - if isinstance(obj, DataArray): + if isinstance(obj, (DataArray, Dataset)): return obj.pint.quantify(units) else: return Quantity(obj, units=units) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 4de43938..51224772 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -146,7 +146,7 @@ def freq(period): with pytest.raises(TypeError, match="function returned None"): freq(p) - def test_unit_dict(self): + def test_input_unit_dict(self): @expects({"m": "kg", "a": "m / s^2"}, return_units="newtons") def second_law(ds): return ds["m"] * ds["a"] @@ -158,3 +158,18 @@ def second_law(ds): expected_da = (m_da * a_da).pint.to("newtons") result_da = second_law(ds) assert result_da == expected_da + + def test_return_dataset(self): + @expects({"m": "kg", "a": "m / s^2"}, return_units=[{"f": "newtons"}]) + def second_law(ds): + f_da = ds["m"] * ds["a"] + return xr.Dataset({"f": f_da}) + + m_da = xr.DataArray(0.1).pint.quantify(units="tons") + a_da = xr.DataArray(10).pint.quantify(units="feet / second^2") + ds = xr.Dataset({"m": m_da, "a": a_da}) + + expected_da = m_da * a_da + expected_ds = xr.Dataset({"f": expected_da}).pint.to({"f": "newtons"}) + result_ds = second_law(ds) + assert result_ds == expected_ds From 9c19af02d87cb8fd789fb074e5b442323be65355 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Fri, 14 Jan 2022 16:16:52 -0500 Subject: [PATCH 31/50] Update docs/api.rst Co-authored-by: keewis --- docs/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index d172555d..55d72379 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -66,7 +66,7 @@ DataArray xarray.DataArray.pint.interpolate_na Checking -------- +-------- .. autosummary:: :toctree: generated/ From 0b5c7c0811b352f518601fd7f71c8e9f49730eed Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 14 Jan 2022 16:17:57 -0500 Subject: [PATCH 32/50] correct indentation of docstring --- pint_xarray/checking.py | 106 ++++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 0d4a5d8f..aa2ebd5a 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -8,70 +8,70 @@ def expects(*args_units, return_units=None, **kwargs_units): """ - Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. - - Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and - then stripped of their units before being passed into the undecorated function. Therefore the undecorated function - should expect unquantified DataArrays, Datasets, or numpy-like arrays, but with the values expressed in specific units. - - Note that the coordinates of input DataArrays are not checked, only the data. - So if your decorated function uses coordinates and you wish to check their units, - you should pass the coordinates of interest as separate arguments. - - Parameters - ---------- - func : callable - Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, - and may optionally return one or more xarray.DataArrays or numpy-like arrays. - *args_units : unit-like or mapping of hashable to unit-like, optional - Units to expect for each positional argument given to func. - - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. - - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). + Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. + + Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and + then stripped of their units before being passed into the undecorated function. Therefore the undecorated function + should expect unquantified DataArrays, Datasets, or numpy-like arrays, but with the values expressed in specific units. + + Note that the coordinates of input DataArrays are not checked, only the data. + So if your decorated function uses coordinates and you wish to check their units, + you should pass the coordinates of interest as separate arguments. + + Parameters + ---------- + func : callable + Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, + and may optionally return one or more xarray.DataArrays or numpy-like arrays. + *args_units : unit-like or mapping of hashable to unit-like, optional + Units to expect for each positional argument given to func. + + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. + + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). return_units : unit-like or mapping of hashable to unit-like or list of unit-like or list of mapping of hashable to unit-like, optional - The expected units of the returned value(s), either as a single unit or as a list of units. The decorator - will attach these units to the variables returned from the function. + The expected units of the returned value(s), either as a single unit or as a list of units. The decorator + will attach these units to the variables returned from the function. - A value of None indicates not to attach any units to that return value (suitable for flags and other - non-data results). - kwargs_units : mapping of hashable to unit-like, optional - Unit to expect for each keyword argument given to func. + A value of None indicates not to attach any units to that return value (suitable for flags and other + non-data results). + kwargs_units : mapping of hashable to unit-like, optional + Unit to expect for each keyword argument given to func. - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. + The decorator will first check that arguments passed to the decorated function possess these specific units + (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude + to the wrapped function. - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). + A value of None indicates not to check that argument for units (suitable for flags and other non-data + arguments). - Returns - ------- - return_values : Any - Return values of the wrapped function, either a single value or a tuple of values. These will be given units - according to return_units. + Returns + ------- + return_values : Any + Return values of the wrapped function, either a single value or a tuple of values. These will be given units + according to return_units. - Raises - ------ - TypeError - If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. - Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not - match the number of units specified. + Raises + ------ + TypeError + If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. + Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not + match the number of units specified. - Examples - -------- + Examples + -------- - Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). + Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). - >>> @expects("deg C") - ... def above_freezing(temp): - ... return temp > 0 + >>> @expects("deg C") + ... def above_freezing(temp): + ... return temp > 0 - TODO: example where we check units of an optional weighted kwarg + TODO: example where we check units of an optional weighted kwarg """ # Check types of args_units, kwargs_units, and return_units From 0f503056e7c7d4a8151ad3ae88ede3574f23a65e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 14 Jan 2022 16:59:12 -0500 Subject: [PATCH 33/50] use inspects to check number of arguments passed to decorated function --- pint_xarray/checking.py | 34 ++++++++++++++++++++++-------- pint_xarray/tests/test_checking.py | 17 +++++++-------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index aa2ebd5a..a410c0a3 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -1,4 +1,6 @@ import functools +import inspect +from inspect import Parameter from pint import Quantity, Unit from xarray import DataArray, Dataset @@ -69,11 +71,10 @@ def expects(*args_units, return_units=None, **kwargs_units): >>> @expects("deg C") ... def above_freezing(temp): ... return temp > 0 - - - TODO: example where we check units of an optional weighted kwarg """ + # TODO: example where we check units of an optional weighted kwarg + # Check types of args_units, kwargs_units, and return_units all_units = list(args_units) + list(kwargs_units.values()) if isinstance(return_units, list): @@ -88,15 +89,30 @@ def expects(*args_units, return_units=None, **kwargs_units): _check_valid_unit_type(a) def _expects_decorator(func): + + # check same number of arguments were passed as expected + sig = inspect.signature(func) + positional_args = ( + Parameter.POSITIONAL_ONLY, + Parameter.VAR_POSITIONAL, + Parameter.POSITIONAL_OR_KEYWORD, + ) + n_args = len( + [ + param + for param in sig.parameters.values() + if param.kind in positional_args and param.default is param.empty + ] + ) + if n_args != len(args_units): + raise TypeError( + f"The `expects` decorator used expects {len(args_units)} arguments, but a function expecting {n_args} " + f"arguments was wrapped" + ) + @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): - # check same number of arguments were passed as expected - if len(args) != len(args_units): - raise TypeError( - f"{len(args)} arguments were passed, but {len(args_units)} arguments were expected" - ) - converted_args = [ _check_or_convert_to_then_strip(arg, arg_unit) for arg, arg_unit in zip(args, args_units) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 51224772..7297889e 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -115,15 +115,14 @@ def test_unquantified_arrays(self): raise NotImplementedError def test_wrong_number_of_args(self): - @expects("kg", return_units="newtons") - def second_law(m, a): - return m * a - - m_q = pint.Quantity(0.1, units="tons") - a_q = pint.Quantity(10, units="feet / second^2") - - with pytest.raises(TypeError, match="1 arguments were expected"): - second_law(m_q, a_q) + with pytest.raises( + TypeError, + match="expects 1 arguments, but a function expecting 2 arguments was wrapped", + ): + + @expects("kg", return_units="newtons") + def second_law(m, a): + return m * a def test_wrong_number_of_return_values(self): @expects("kg", "m / s^2", return_units=["newtons", "joules"]) From 57d341e01f998ff383dede0e8bdeb098116ac981 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 15 Jan 2022 18:20:14 +0100 Subject: [PATCH 34/50] reformat the docstring --- pint_xarray/checking.py | 92 ++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 37 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index a410c0a3..a665ac32 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -10,63 +10,81 @@ def expects(*args_units, return_units=None, **kwargs_units): """ - Decorator which ensures the inputs and outputs of the decorated function are expressed in the expected units. - - Arguments to the decorated function are checked for the specified units, converting to those units if necessary, and - then stripped of their units before being passed into the undecorated function. Therefore the undecorated function - should expect unquantified DataArrays, Datasets, or numpy-like arrays, but with the values expressed in specific units. - - Note that the coordinates of input DataArrays are not checked, only the data. - So if your decorated function uses coordinates and you wish to check their units, - you should pass the coordinates of interest as separate arguments. + Decorator which ensures the inputs and outputs of the decorated + function are expressed in the expected units. + + Arguments to the decorated function are checked for the specified + units, converting to those units if necessary, and then stripped + of their units before being passed into the undecorated + function. Therefore the undecorated function should expect + unquantified DataArrays, Datasets, or numpy-like arrays, but with + the values expressed in specific units. + + .. note:: + The coordinates of input DataArrays are not checked, only the + data. So if your decorated function uses coordinates and you + wish to check their units, you should pass the coordinates of + interest as separate arguments. Parameters ---------- func : callable - Function to decorate. which accepts zero or more xarray.DataArrays or numpy-like arrays as inputs, - and may optionally return one or more xarray.DataArrays or numpy-like arrays. + Function to decorate, which accepts zero or more + xarray.DataArrays or numpy-like arrays as inputs, and may + optionally return one or more xarray.DataArrays or numpy-like + arrays. *args_units : unit-like or mapping of hashable to unit-like, optional - Units to expect for each positional argument given to func. - - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. - - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). - return_units : unit-like or mapping of hashable to unit-like or list of unit-like or list of mapping of hashable to unit-like, optional - The expected units of the returned value(s), either as a single unit or as a list of units. The decorator - will attach these units to the variables returned from the function. - - A value of None indicates not to attach any units to that return value (suitable for flags and other - non-data results). + Units to expect for each positional argument given to func. + + The decorator will first check that arguments passed to the + decorated function possess these specific units (or will + attempt to convert the argument to these units), then will + strip the units before passing the magnitude to the wrapped + function. + + A value of None indicates not to check that argument for units + (suitable for flags and other non-data arguments). + return_units : unit-like or list of unit-like or mapping of hashable to unit-like \ + or list of mapping of hashable to unit-like, optional + The expected units of the returned value(s), either as a + single unit or as a list of units. The decorator will attach + these units to the variables returned from the function. + + A value of None indicates not to attach any units to that + return value (suitable for flags and other non-data results). kwargs_units : mapping of hashable to unit-like, optional - Unit to expect for each keyword argument given to func. + Unit to expect for each keyword argument given to func. - The decorator will first check that arguments passed to the decorated function possess these specific units - (or will attempt to convert the argument to these units), then will strip the units before passing the magnitude - to the wrapped function. + The decorator will first check that arguments passed to the + decorated function possess these specific units (or will + attempt to convert the argument to these units), then will + strip the units before passing the magnitude to the wrapped + function. - A value of None indicates not to check that argument for units (suitable for flags and other non-data - arguments). + A value of None indicates not to check that argument for units + (suitable for flags and other non-data arguments). Returns ------- return_values : Any - Return values of the wrapped function, either a single value or a tuple of values. These will be given units - according to return_units. + Return values of the wrapped function, either a single value + or a tuple of values. These will be given units according to + return_units. Raises ------ TypeError - If an argument or return value has a specified unit, but is not an xarray.DataArray or pint.Quantity. - Also thrown if any of the units are not a valid type, or if the number of arguments or return values does not - match the number of units specified. + If an argument or return value has a specified unit, but is + not an xarray.DataArray or pint.Quantity. Also thrown if any + of the units are not a valid type, or if the number of + arguments or return values does not match the number of units + specified. Examples -------- - Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean). + Decorating a function which takes one quantified input, but + returns a non-data value (in this case a boolean). >>> @expects("deg C") ... def above_freezing(temp): From 8845b779884f6fe1838c893e1fac018793d2583f Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 16 Jan 2022 12:45:50 +0100 Subject: [PATCH 35/50] update the definition of unit-like --- docs/terminology.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/terminology.rst b/docs/terminology.rst index 4fe6534d..4532063e 100644 --- a/docs/terminology.rst +++ b/docs/terminology.rst @@ -5,6 +5,7 @@ Terminology unit-like A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`. - May be either a :py:class:`str` or a :py:class:`pint.Unit` instance. + May be either a :py:class:`str`, a :py:class:`pint.Unit` + instance, or :py:obj:`None`. .. _pint: https://pint.readthedocs.io/en/stable From bc41425847862bcb31f1fe6de04774a8c763ae89 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 18 Jan 2022 13:47:09 -0500 Subject: [PATCH 36/50] simplify if/else statement --- pint_xarray/checking.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index a410c0a3..2ff737b0 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -154,11 +154,10 @@ def _unit_checking_wrapper(*args, **kwargs): if isinstance(results, tuple): return converted_results + elif len(converted_results) == 1: + return converted_results[0] else: - if len(converted_results) == 1: - return converted_results[0] - else: - return converted_results + return converted_results return _unit_checking_wrapper From 03503080a37a300bed387ef287397bb54c461cda Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 19 Jan 2022 16:40:48 -0500 Subject: [PATCH 37/50] check units in .to instead --- pint_xarray/checking.py | 20 +------------------- pint_xarray/tests/test_checking.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 71a31c72..98c141e3 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -2,7 +2,7 @@ import inspect from inspect import Parameter -from pint import Quantity, Unit +from pint import Quantity from xarray import DataArray, Dataset from .accessors import PintDataArrayAccessor # noqa @@ -93,19 +93,6 @@ def expects(*args_units, return_units=None, **kwargs_units): # TODO: example where we check units of an optional weighted kwarg - # Check types of args_units, kwargs_units, and return_units - all_units = list(args_units) + list(kwargs_units.values()) - if isinstance(return_units, list): - all_units = all_units + return_units - elif return_units: - all_units = all_units + [return_units] - for a in all_units: - if isinstance(a, dict): - for u in a.values(): - _check_valid_unit_type(u) - else: - _check_valid_unit_type(a) - def _expects_decorator(func): # check same number of arguments were passed as expected @@ -182,11 +169,6 @@ def _unit_checking_wrapper(*args, **kwargs): return _expects_decorator -def _check_valid_unit_type(a): - if not isinstance(a, (Unit, str)) and a is not None: - raise TypeError(f"{a} is not a valid type for a unit, it is of type {type(a)}") - - def _check_or_convert_to_then_strip(obj, units): """ Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units, diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 7297889e..5bca47bc 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -97,14 +97,18 @@ def finite_difference(a, type): finite_difference(t, "centered") @pytest.mark.parametrize( - "arg_units, return_units", [(True, None), ("seconds", 6), ("seconds", [6])] + "arg_units, return_units", + [("nonsense", "Hertz"), ("seconds", 6), ("seconds", [6])], ) def test_invalid_unit_types(self, arg_units, return_units): - with pytest.raises(TypeError): + @expects(arg_units, return_units=return_units) + def freq(period): + return 1 / period + + q = pint.Quantity(1.0, units="seconds") - @expects(arg_units, return_units=return_units) - def freq(period): - ... + with pytest.raises((TypeError, pint.errors.UndefinedUnitError)): + freq(q) @pytest.mark.xfail def test_invalid_return_types(self): From 3a24a730a60f633643bce070f8decba7d52d8015 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 19 Jan 2022 16:47:08 -0500 Subject: [PATCH 38/50] remove extra xfailed test --- pint_xarray/tests/test_checking.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 5bca47bc..e1749292 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -110,10 +110,6 @@ def freq(period): with pytest.raises((TypeError, pint.errors.UndefinedUnitError)): freq(q) - @pytest.mark.xfail - def test_invalid_return_types(self): - raise NotImplementedError - @pytest.mark.xfail def test_unquantified_arrays(self): raise NotImplementedError From 19fd6e0c0658007f14635add07d66c75bdda2874 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 20 Jan 2022 11:26:27 -0500 Subject: [PATCH 39/50] test raises on unquantified input --- pint_xarray/tests/test_checking.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index e1749292..395f271b 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -110,9 +110,18 @@ def freq(period): with pytest.raises((TypeError, pint.errors.UndefinedUnitError)): freq(q) - @pytest.mark.xfail def test_unquantified_arrays(self): - raise NotImplementedError + @expects("seconds", return_units="Hertz") + def freq(period): + return 1 / period + + da = xr.DataArray(10) + + with pytest.raises( + ValueError, + match="cannot convert a non-quantity", + ): + freq(da) def test_wrong_number_of_args(self): with pytest.raises( From d2d74e47fba73f9703f9b6cfc7847c21a47696eb Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 20 Jan 2022 12:55:44 -0500 Subject: [PATCH 40/50] add example of function which optionally accepts dimensionless weights --- pint_xarray/checking.py | 13 +++++++++++-- pint_xarray/tests/test_checking.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 98c141e3..572e97b2 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -89,9 +89,18 @@ def expects(*args_units, return_units=None, **kwargs_units): >>> @expects("deg C") ... def above_freezing(temp): ... return temp > 0 - """ - # TODO: example where we check units of an optional weighted kwarg + Decorating a function which allows any dimensions for the array, but also + accepts an optional `weights` keyword argument, which must be dimensionless. + + >>> @expects(None, weights="dimensionless") + ... def mean(da, weights=None): + ... if weights: + ... return da.weighted(weights=weights).mean() + ... else: + ... return da.mean() + + """ def _expects_decorator(func): diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 395f271b..5eb7d891 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -48,6 +48,21 @@ def freq(wavelength, c=None): f_da = freq(w_da) assert f_da.pint.units == pint.Unit("hertz") + def test_weighted_kwarg(self): + @expects(None, weights="dimensionless") + def mean(da, weights=None): + if weights is not None: + return da.weighted(weights=weights).mean() + else: + return da.mean() + + d = xr.DataArray([1, 2, 3]).pint.quantify(units="metres") + w = xr.DataArray([0.1, 0.7, 0.2]).pint.quantify(units="dimensionless") + + result = mean(d, weights=w) + expected = xr.DataArray(0.21).pint.quantify("metres") + assert result.pint.units == expected.pint.units + def test_single_return_value(self): @expects("kg", "m / s^2", return_units="newtons") def second_law(m, a): From 85b982c4289f5223a9e1c6e6b5ec43c7c4c911d8 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 20 Sep 2022 17:55:22 +0200 Subject: [PATCH 41/50] rewrite using inspect.Signature's bind and bind_partial --- pint_xarray/checking.py | 245 +++++++++++++++++++++------------------- 1 file changed, 127 insertions(+), 118 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 572e97b2..aadc09f1 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -8,6 +8,119 @@ from .accessors import PintDataArrayAccessor # noqa +def detect_missing_params(params, units): + """detect parameters for which no units were specified""" + variable_params = { + Parameter.VAR_POSITIONAL, + Parameter.VAR_KEYWORD, + } + + return { + name + for name, param in params.items() + if name not in units.arguments and param.kind not in variable_params + } + + +def convert_and_strip_args(args, units): + pass + + +def convert_and_strip_kwargs(kwargs, units): + pass + + +def attach_return_units(results, units): + if units is None: + # ignore types and units of return values + return results + elif results is None: + raise TypeError( + "Expected function to return something, but function returned None" + ) + else: + # handle case of function returning only one result by promoting to 1-element tuple + return_units_iterable = tuple(always_iterable(units, base_type=(str, dict))) + results_iterable = tuple(always_iterable(results, base_type=(str, Dataset))) + + # check same number of things were returned as expected + if len(results_iterable) != len(return_units_iterable): + raise TypeError( + f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} " + "return values were expected" + ) + + converted_results = _attach_multiple_units( + results_iterable, return_units_iterable + ) + + if isinstance(results, tuple): + return converted_results + elif len(converted_results) == 1: + return converted_results[0] + else: + return converted_results + + +def _check_or_convert_to_then_strip(obj, units): + """ + Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units, + then strips the units from it. + """ + + if units is None: + # allow for passing through non-numerical arguments + return obj + elif isinstance(obj, Quantity): + converted = obj.to(units) + return converted.magnitude + elif isinstance(obj, (DataArray, Dataset)): + converted = obj.pint.to(units) + return converted.pint.dequantify() + else: + raise TypeError( + "Can only expect units for arguments of type xarray.DataArray," + f" xarray.Dataset, or pint.Quantity, not {type(obj)}" + ) + + +def _attach_units(obj, units): + """Attaches units, but can also create pint.Quantity objects from numpy scalars""" + if isinstance(obj, (DataArray, Dataset)): + return obj.pint.quantify(units) + else: + return Quantity(obj, units=units) + + +def _attach_multiple_units(objects, units): + """Attaches list of units to list of objects elementwise""" + converted_objects = [_attach_units(obj, unit) for obj, unit in zip(objects, units)] + return converted_objects + + +def always_iterable(obj, base_type=(str, bytes)): + """ + If *obj* is iterable, return an iterator over its items, + If *obj* is not iterable, return a one-item iterable containing *obj*, + If *obj* is ``None``, return an empty iterable. + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + Copied from more_itertools. + """ + + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + def expects(*args_units, return_units=None, **kwargs_units): """ Decorator which ensures the inputs and outputs of the decorated @@ -106,132 +219,28 @@ def _expects_decorator(func): # check same number of arguments were passed as expected sig = inspect.signature(func) - positional_args = ( - Parameter.POSITIONAL_ONLY, - Parameter.VAR_POSITIONAL, - Parameter.POSITIONAL_OR_KEYWORD, - ) - n_args = len( - [ - param - for param in sig.parameters.values() - if param.kind in positional_args and param.default is param.empty - ] - ) - if n_args != len(args_units): - raise TypeError( - f"The `expects` decorator used expects {len(args_units)} arguments, but a function expecting {n_args} " - f"arguments was wrapped" - ) + + params = sig.parameters + + bound_units = sig.bind_partial(*args_units, **kwargs_units) + + missing_params = detect_missing_params(params, bound_units) + if missing_params: + raise ValueError(f"no units for {missing_params}") @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) - converted_args = [ - _check_or_convert_to_then_strip(arg, arg_unit) - for arg, arg_unit in zip(args, args_units) - ] - - converted_kwargs = { - key: _check_or_convert_to_then_strip(val, kwargs_units.get(key, None)) - for key, val in kwargs.items() - } + converted_args = convert_and_strip_args(bound.args, bound_units.args) + converted_kwargs = convert_and_strip_kwargs( + bound.kwargs, bound_units.kwargs + ) results = func(*converted_args, **converted_kwargs) - if return_units is None: - # ignore types and units of return values - return results - elif results is None: - raise TypeError( - "Expected function to return something, but function returned None" - ) - else: - # handle case of function returning only one result by promoting to 1-element tuple - return_units_iterable = tuple( - always_iterable(return_units, base_type=(str, dict)) - ) - results_iterable = tuple( - always_iterable(results, base_type=(str, Dataset)) - ) - - # check same number of things were returned as expected - if len(results_iterable) != len(return_units_iterable): - raise TypeError( - f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} " - "return values were expected" - ) - - converted_results = _attach_multiple_units( - results_iterable, return_units_iterable - ) - - if isinstance(results, tuple): - return converted_results - elif len(converted_results) == 1: - return converted_results[0] - else: - return converted_results + return attach_return_units(results, return_units) return _unit_checking_wrapper return _expects_decorator - - -def _check_or_convert_to_then_strip(obj, units): - """ - Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units, - then strips the units from it. - """ - - if units is None: - # allow for passing through non-numerical arguments - return obj - elif isinstance(obj, Quantity): - converted = obj.to(units) - return converted.magnitude - elif isinstance(obj, (DataArray, Dataset)): - converted = obj.pint.to(units) - return converted.pint.dequantify() - else: - raise TypeError( - "Can only expect units for arguments of type xarray.DataArray," - f" xarray.Dataset, or pint.Quantity, not {type(obj)}" - ) - - -def _attach_units(obj, units): - """Attaches units, but can also create pint.Quantity objects from numpy scalars""" - if isinstance(obj, (DataArray, Dataset)): - return obj.pint.quantify(units) - else: - return Quantity(obj, units=units) - - -def _attach_multiple_units(objects, units): - """Attaches list of units to list of objects elementwise""" - converted_objects = [_attach_units(obj, unit) for obj, unit in zip(objects, units)] - return converted_objects - - -def always_iterable(obj, base_type=(str, bytes)): - """ - If *obj* is iterable, return an iterator over its items, - If *obj* is not iterable, return a one-item iterable containing *obj*, - If *obj* is ``None``, return an empty iterable. - If *base_type* is set, objects for which ``isinstance(obj, base_type)`` - returns ``True`` won't be considered iterable. - - Copied from more_itertools. - """ - - if obj is None: - return iter(()) - - if (base_type is not None) and isinstance(obj, base_type): - return iter((obj,)) - - try: - return iter(obj) - except TypeError: - return iter((obj,)) From 5ea484bb85406d480821470fa446583fdca60709 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:42:59 +0200 Subject: [PATCH 42/50] also allow converting and stripping Variable objects --- pint_xarray/conversion.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py index b18b8a63..c0860b26 100644 --- a/pint_xarray/conversion.py +++ b/pint_xarray/conversion.py @@ -223,7 +223,9 @@ def convert_units_dataset(obj, units): def convert_units(obj, units): - if not isinstance(obj, (DataArray, Dataset)): + if isinstance(obj, Variable): + return convert_units_variable(obj, units) + elif not isinstance(obj, (DataArray, Dataset)): raise ValueError(f"cannot convert object: {obj!r}: unknown type") if isinstance(obj, DataArray): @@ -299,7 +301,9 @@ def strip_units_dataset(obj): def strip_units(obj): - if not isinstance(obj, (DataArray, Dataset)): + if isinstance(obj, Variable): + return strip_units_variable(obj) + elif not isinstance(obj, (DataArray, Dataset)): raise ValueError("cannot strip units from {obj!r}: unknown type") return call_on_dataset(strip_units_dataset, obj, name=temporary_name) From b7c71c1cbee51200e72e9caa31574b4315a0a595 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:44:55 +0200 Subject: [PATCH 43/50] implement the conversion functions --- pint_xarray/checking.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index aadc09f1..96f214a7 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -3,8 +3,9 @@ from inspect import Parameter from pint import Quantity -from xarray import DataArray, Dataset +from xarray import DataArray, Dataset, Variable +from . import conversion from .accessors import PintDataArrayAccessor # noqa @@ -22,12 +23,25 @@ def detect_missing_params(params, units): } +def convert_and_strip(obj, units): + if isinstance(obj, (DataArray, Dataset, Variable)): + if not isinstance(units, dict): + units = {None: units} + return conversion.strip_units(conversion.convert_units(obj, units)) + elif isinstance(obj, Quantity): + return obj.m_as(units) + elif units is None: + return obj + else: + raise ValueError(f"unknown type: {type(obj)}") + + def convert_and_strip_args(args, units): - pass + return [convert_and_strip(obj, units_) for obj, units_ in zip(args, units)] def convert_and_strip_kwargs(kwargs, units): - pass + return {name: convert_and_strip(kwargs[name], units[name]) for name in kwargs} def attach_return_units(results, units): From b39fff3c14766bfd261200eaf04cb99d90bf5dc6 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:46:10 +0200 Subject: [PATCH 44/50] simplify the return construct --- pint_xarray/checking.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 96f214a7..4a554d60 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -68,12 +68,10 @@ def attach_return_units(results, units): results_iterable, return_units_iterable ) - if isinstance(results, tuple): + if isinstance(results, tuple) or len(converted_results) != 1: return converted_results - elif len(converted_results) == 1: - return converted_results[0] else: - return converted_results + return converted_results[0] def _check_or_convert_to_then_strip(obj, units): From 61e02999b2f1527bc89de63ce652ae5719192dc4 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:51:33 +0200 Subject: [PATCH 45/50] code reorganization --- pint_xarray/checking.py | 46 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index 4a554d60..d507fd69 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -44,6 +44,29 @@ def convert_and_strip_kwargs(kwargs, units): return {name: convert_and_strip(kwargs[name], units[name]) for name in kwargs} +def always_iterable(obj, base_type=(str, bytes)): + """ + If *obj* is iterable, return an iterator over its items, + If *obj* is not iterable, return a one-item iterable containing *obj*, + If *obj* is ``None``, return an empty iterable. + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + Copied from more_itertools. + """ + + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + def attach_return_units(results, units): if units is None: # ignore types and units of return values @@ -110,29 +133,6 @@ def _attach_multiple_units(objects, units): return converted_objects -def always_iterable(obj, base_type=(str, bytes)): - """ - If *obj* is iterable, return an iterator over its items, - If *obj* is not iterable, return a one-item iterable containing *obj*, - If *obj* is ``None``, return an empty iterable. - If *base_type* is set, objects for which ``isinstance(obj, base_type)`` - returns ``True`` won't be considered iterable. - - Copied from more_itertools. - """ - - if obj is None: - return iter(()) - - if (base_type is not None) and isinstance(obj, base_type): - return iter((obj,)) - - try: - return iter(obj) - except TypeError: - return iter((obj,)) - - def expects(*args_units, return_units=None, **kwargs_units): """ Decorator which ensures the inputs and outputs of the decorated From 63d8aeb43042136fea0b5847c480bafe6631e2c0 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:53:59 +0200 Subject: [PATCH 46/50] black --- pint_xarray/tests/test_checking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 5eb7d891..9e43cb7d 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -83,13 +83,13 @@ def second_law(m, a): def test_multiple_return_values(self): @expects("kg", "m / s", return_units=["J", "newton seconds"]) def energy_and_momentum(m, v): - ke = 0.5 * m * v ** 2 + ke = 0.5 * m * v**2 p = m * v return ke, p m = pint.Quantity(0.1, units="tons") v = pint.Quantity(10, units="feet / second") - expected_ke = (0.5 * m * v ** 2).to("J") + expected_ke = (0.5 * m * v**2).to("J") expected_p = (m * v).to("newton seconds") result_ke, result_p = energy_and_momentum(m, v) assert result_ke.units == expected_ke.units @@ -97,7 +97,7 @@ def energy_and_momentum(m, v): m = xr.DataArray(0.1).pint.quantify(units="tons") v = xr.DataArray(10).pint.quantify(units="feet / second") - expected_ke = (0.5 * m * v ** 2).pint.to("J") + expected_ke = (0.5 * m * v**2).pint.to("J") expected_p = (m * v).pint.to("newton seconds") result_ke, result_p = energy_and_momentum(m, v) assert result_ke.pint.units == expected_ke.pint.units From 32a57b239fd3a30bc5372b18c1a95aeb76bfff11 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 11:54:06 +0200 Subject: [PATCH 47/50] fix a test --- pint_xarray/tests/test_checking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index 9e43cb7d..f881d04b 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -49,7 +49,7 @@ def freq(wavelength, c=None): assert f_da.pint.units == pint.Unit("hertz") def test_weighted_kwarg(self): - @expects(None, weights="dimensionless") + @expects(None, weights="dimensionless", return_units="metres") def mean(da, weights=None): if weights is not None: return da.weighted(weights=weights).mean() From 91b482605fa33d8506e350ca46a167c050589dbf Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 12:01:14 +0200 Subject: [PATCH 48/50] remove the note about coordinates not being checked [skip-ci] --- pint_xarray/checking.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index d507fd69..ffdc9d9d 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -145,12 +145,6 @@ def expects(*args_units, return_units=None, **kwargs_units): unquantified DataArrays, Datasets, or numpy-like arrays, but with the values expressed in specific units. - .. note:: - The coordinates of input DataArrays are not checked, only the - data. So if your decorated function uses coordinates and you - wish to check their units, you should pass the coordinates of - interest as separate arguments. - Parameters ---------- func : callable From a43dd13017653e6bd7672a699edad6fed57ac894 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 14:22:45 +0200 Subject: [PATCH 49/50] reword the error message raised when there's no units for some parameters --- pint_xarray/checking.py | 5 ++++- pint_xarray/tests/test_checking.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py index ffdc9d9d..cc7b7e23 100644 --- a/pint_xarray/checking.py +++ b/pint_xarray/checking.py @@ -232,7 +232,10 @@ def _expects_decorator(func): missing_params = detect_missing_params(params, bound_units) if missing_params: - raise ValueError(f"no units for {missing_params}") + raise TypeError( + "Some parameters of the decorated function are missing units:" + f" {', '.join(sorted(missing_params))}" + ) @functools.wraps(func) def _unit_checking_wrapper(*args, **kwargs): diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py index f881d04b..068729d2 100644 --- a/pint_xarray/tests/test_checking.py +++ b/pint_xarray/tests/test_checking.py @@ -141,7 +141,7 @@ def freq(period): def test_wrong_number_of_args(self): with pytest.raises( TypeError, - match="expects 1 arguments, but a function expecting 2 arguments was wrapped", + match="Some parameters of the decorated function are missing units", ): @expects("kg", return_units="newtons") From 7ea921c4c77a7452de829678eeca629567b2b61d Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 21 Sep 2022 14:40:57 +0200 Subject: [PATCH 50/50] move the changelog to a new section --- docs/whats-new.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/whats-new.rst b/docs/whats-new.rst index c8dc4e59..2dfae1e5 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -2,6 +2,12 @@ What's new ========== +0.4 (unreleased) +---------------- + +- Added the :py:func:`pint_xarray.expects` decorator (:pull:`143`). + By `Tom Nicholas `_ and `Justus Magin `_. + 0.3 (27 Jul 2022) ----------------- - drop support for python 3.7 (:pull:`153`) @@ -16,12 +22,6 @@ What's new as identity operators (:issue:`47`, :pull:`175`). By `Justus Magin `_. -0.2.2 (unreleased) ------------------- - -- Added the :py:func:`pint_xarray.expects` decorator (:pull:`143`). - By `Tom Nicholas `_ - 0.2.1 (26 Jul 2021) ------------------- - allow special "no unit" values in :py:meth:`Dataset.pint.quantify` and