Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expects decorator #143

Open
wants to merge 55 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
044d59a
draft implementation of @expects
TomNicholas Oct 20, 2021
0754f22
sketch of different tests needed
TomNicholas Oct 20, 2021
e879ef9
idea for test
TomNicholas Oct 22, 2021
aad7936
upgrade check then convert function to optionally take magnitude
TomNicholas Oct 28, 2021
e354f4e
removed magnitude option
TomNicholas Nov 29, 2021
7727d8e
works for single return value
TomNicholas Nov 29, 2021
1379779
works for single kwarg
TomNicholas Nov 30, 2021
77f5d02
works for multiple return values
TomNicholas Nov 30, 2021
71f4200
allow passing through arguments unchecked
TomNicholas Nov 30, 2021
a710741
check types of units
TomNicholas Nov 30, 2021
497e97f
remove uneeded option to specify a lack of return value
TomNicholas Nov 30, 2021
00219bc
check number of inputs and return values
TomNicholas Nov 30, 2021
9e92f21
removed nonlocal keyword
TomNicholas Nov 30, 2021
86f7e58
generalised to handle specifying dicts of units
TomNicholas Nov 30, 2021
2141c6c
type hint for func
TomNicholas Nov 30, 2021
a94a6ae
type hint for args_units
TomNicholas Nov 30, 2021
a2cc63f
Merge branch 'expects_decorator' of https://github.com/TomNicholas/pi…
TomNicholas Nov 30, 2021
7103483
numpy-style type hints for all arguments
TomNicholas Nov 30, 2021
59ddf86
whats new
TomNicholas Nov 30, 2021
a5a2493
add to API docs
TomNicholas Nov 30, 2021
e5e84fb
use always_iterable
TomNicholas Dec 14, 2021
3f59414
hashable
TomNicholas Dec 14, 2021
b281674
hashable
TomNicholas Dec 14, 2021
c669105
hashable
TomNicholas Dec 14, 2021
9ac8887
dict comprehension
TomNicholas Dec 14, 2021
0a6447d
list comprehension
TomNicholas Dec 14, 2021
c29f935
unindent if/else
TomNicholas Dec 14, 2021
81913a6
missing parenthesis
TomNicholas Dec 14, 2021
4de6f4d
simplify if/else logic for checking there were actually results
TomNicholas Dec 14, 2021
83e422f
return results immediately if a tuple
TomNicholas Dec 14, 2021
37c3fbc
allow for returning Datasets from wrapped funciton
TomNicholas Dec 14, 2021
9c19af0
Update docs/api.rst
TomNicholas Jan 14, 2022
0b5c7c0
correct indentation of docstring
TomNicholas Jan 14, 2022
0f50305
use inspects to check number of arguments passed to decorated function
TomNicholas Jan 14, 2022
57d341e
reformat the docstring
keewis Jan 15, 2022
8845b77
update the definition of unit-like
keewis Jan 16, 2022
bc41425
simplify if/else statement
TomNicholas Jan 18, 2022
aba2d11
Merge branch 'expects_decorator' of https://github.com/TomNicholas/pi…
TomNicholas Jan 18, 2022
0350308
check units in .to instead
TomNicholas Jan 19, 2022
3a24a73
remove extra xfailed test
TomNicholas Jan 19, 2022
19fd6e0
test raises on unquantified input
TomNicholas Jan 20, 2022
d2d74e4
add example of function which optionally accepts dimensionless weights
TomNicholas Jan 20, 2022
1c4feb4
Merge branch 'main' into expects_decorator
keewis Mar 11, 2022
7a6f2cb
Merge branch 'main' into expects_decorator
keewis Sep 20, 2022
85b982c
rewrite using inspect.Signature's bind and bind_partial
keewis Sep 20, 2022
5ea484b
also allow converting and stripping Variable objects
keewis Sep 21, 2022
b7c71c1
implement the conversion functions
keewis Sep 21, 2022
b39fff3
simplify the return construct
keewis Sep 21, 2022
61e0299
code reorganization
keewis Sep 21, 2022
63d8aeb
black
keewis Sep 21, 2022
32a57b2
fix a test
keewis Sep 21, 2022
91b4826
remove the note about coordinates not being checked [skip-ci]
keewis Sep 21, 2022
a43dd13
reword the error message raised when there's no units for some parame…
keewis Sep 21, 2022
7ea921c
move the changelog to a new section
keewis Sep 21, 2022
b92087a
Merge branch 'main' into expects_decorator
keewis Sep 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ DataArray
xarray.DataArray.pint.bfill
xarray.DataArray.pint.interpolate_na

Checking
--------

.. autosummary::
:toctree: generated/

pint_xarray.expects

Testing
-------

Expand Down
3 changes: 2 additions & 1 deletion docs/terminology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Terminology

unit-like
A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`.
May be either a :py:class:`str` or a :py:class:`pint.Unit` instance.
May be either a :py:class:`str`, a :py:class:`pint.Unit`
instance, or :py:obj:`None`.

.. _pint: https://pint.readthedocs.io/en/stable
2 changes: 2 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ What's new
0.4 (*unreleased*)
------------------

- Added the :py:func:`pint_xarray.expects` decorator (:pull:`143`).
By `Tom Nicholas <https://github.com/TomNicholas>`_ and `Justus Magin <https://github.com/keewis>`_.

0.3 (27 Jul 2022)
-----------------
Expand Down
1 change: 1 addition & 0 deletions pint_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import accessors, formatting, testing # noqa: F401
from .accessors import default_registry as unit_registry
from .accessors import setup_registry
from .checking import expects # noqa: F401

try:
__version__ = version("pint-xarray")
Expand Down
255 changes: 255 additions & 0 deletions pint_xarray/checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
import functools
import inspect
from inspect import Parameter

from pint import Quantity
from xarray import DataArray, Dataset, Variable

from . import conversion
from .accessors import PintDataArrayAccessor # noqa


def detect_missing_params(params, units):
"""detect parameters for which no units were specified"""
variable_params = {
Parameter.VAR_POSITIONAL,
Parameter.VAR_KEYWORD,
}

return {
name
for name, param in params.items()
if name not in units.arguments and param.kind not in variable_params
}


def convert_and_strip(obj, units):
if isinstance(obj, (DataArray, Dataset, Variable)):
if not isinstance(units, dict):
units = {None: units}
return conversion.strip_units(conversion.convert_units(obj, units))
elif isinstance(obj, Quantity):
return obj.m_as(units)
elif units is None:
return obj
else:
raise ValueError(f"unknown type: {type(obj)}")


def convert_and_strip_args(args, units):
return [convert_and_strip(obj, units_) for obj, units_ in zip(args, units)]


def convert_and_strip_kwargs(kwargs, units):
return {name: convert_and_strip(kwargs[name], units[name]) for name in kwargs}


def always_iterable(obj, base_type=(str, bytes)):
"""
If *obj* is iterable, return an iterator over its items,
If *obj* is not iterable, return a one-item iterable containing *obj*,
If *obj* is ``None``, return an empty iterable.
If *base_type* is set, objects for which ``isinstance(obj, base_type)``
returns ``True`` won't be considered iterable.

Copied from more_itertools.
"""

if obj is None:
return iter(())

if (base_type is not None) and isinstance(obj, base_type):
return iter((obj,))

try:
return iter(obj)
except TypeError:
return iter((obj,))


def attach_return_units(results, units):
if units is None:
# ignore types and units of return values
return results
elif results is None:
raise TypeError(
"Expected function to return something, but function returned None"
)
else:
# handle case of function returning only one result by promoting to 1-element tuple
return_units_iterable = tuple(always_iterable(units, base_type=(str, dict)))
results_iterable = tuple(always_iterable(results, base_type=(str, Dataset)))

# check same number of things were returned as expected
if len(results_iterable) != len(return_units_iterable):
raise TypeError(
f"{len(results_iterable)} return values were received, but {len(return_units_iterable)} "
"return values were expected"
)

converted_results = _attach_multiple_units(
results_iterable, return_units_iterable
)

if isinstance(results, tuple) or len(converted_results) != 1:
return converted_results
else:
return converted_results[0]


def _check_or_convert_to_then_strip(obj, units):
"""
Checks the object is of a valid type (Quantity or DataArray), then attempts to convert it to the specified units,
then strips the units from it.
"""

if units is None:
# allow for passing through non-numerical arguments
return obj
elif isinstance(obj, Quantity):
converted = obj.to(units)
return converted.magnitude
elif isinstance(obj, (DataArray, Dataset)):
converted = obj.pint.to(units)
return converted.pint.dequantify()
else:
raise TypeError(
"Can only expect units for arguments of type xarray.DataArray,"
f" xarray.Dataset, or pint.Quantity, not {type(obj)}"
)


def _attach_units(obj, units):
"""Attaches units, but can also create pint.Quantity objects from numpy scalars"""
if isinstance(obj, (DataArray, Dataset)):
return obj.pint.quantify(units)
else:
return Quantity(obj, units=units)


def _attach_multiple_units(objects, units):
"""Attaches list of units to list of objects elementwise"""
converted_objects = [_attach_units(obj, unit) for obj, unit in zip(objects, units)]
return converted_objects


def expects(*args_units, return_units=None, **kwargs_units):
"""
Decorator which ensures the inputs and outputs of the decorated
function are expressed in the expected units.

Arguments to the decorated function are checked for the specified
units, converting to those units if necessary, and then stripped
of their units before being passed into the undecorated
function. Therefore the undecorated function should expect
unquantified DataArrays, Datasets, or numpy-like arrays, but with
the values expressed in specific units.

Parameters
----------
func : callable
Function to decorate, which accepts zero or more
xarray.DataArrays or numpy-like arrays as inputs, and may
optionally return one or more xarray.DataArrays or numpy-like
arrays.
*args_units : unit-like or mapping of hashable to unit-like, optional
Units to expect for each positional argument given to func.

The decorator will first check that arguments passed to the
decorated function possess these specific units (or will
attempt to convert the argument to these units), then will
strip the units before passing the magnitude to the wrapped
function.

A value of None indicates not to check that argument for units
(suitable for flags and other non-data arguments).
keewis marked this conversation as resolved.
Show resolved Hide resolved
return_units : unit-like or list of unit-like or mapping of hashable to unit-like \
or list of mapping of hashable to unit-like, optional
The expected units of the returned value(s), either as a
single unit or as a list of units. The decorator will attach
these units to the variables returned from the function.

A value of None indicates not to attach any units to that
return value (suitable for flags and other non-data results).
keewis marked this conversation as resolved.
Show resolved Hide resolved
kwargs_units : mapping of hashable to unit-like, optional
Unit to expect for each keyword argument given to func.

The decorator will first check that arguments passed to the
decorated function possess these specific units (or will
attempt to convert the argument to these units), then will
strip the units before passing the magnitude to the wrapped
function.

A value of None indicates not to check that argument for units
(suitable for flags and other non-data arguments).
keewis marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
return_values : Any
Return values of the wrapped function, either a single value
or a tuple of values. These will be given units according to
return_units.

Raises
------
TypeError
If an argument or return value has a specified unit, but is
not an xarray.DataArray or pint.Quantity. Also thrown if any
of the units are not a valid type, or if the number of
arguments or return values does not match the number of units
specified.
Comment on lines +196 to +200
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we raise ValueError if the number of arguments / return values do not match?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah possibly. I can change that.


Examples
--------

Decorating a function which takes one quantified input, but
returns a non-data value (in this case a boolean).

>>> @expects("deg C")
... def above_freezing(temp):
... return temp > 0

Decorating a function which allows any dimensions for the array, but also
accepts an optional `weights` keyword argument, which must be dimensionless.

>>> @expects(None, weights="dimensionless")
... def mean(da, weights=None):
... if weights:
... return da.weighted(weights=weights).mean()
... else:
... return da.mean()

"""

def _expects_decorator(func):

# check same number of arguments were passed as expected
sig = inspect.signature(func)

params = sig.parameters

bound_units = sig.bind_partial(*args_units, **kwargs_units)

missing_params = detect_missing_params(params, bound_units)
if missing_params:
raise TypeError(
"Some parameters of the decorated function are missing units:"
f" {', '.join(sorted(missing_params))}"
)

@functools.wraps(func)
def _unit_checking_wrapper(*args, **kwargs):
bound = sig.bind(*args, **kwargs)

converted_args = convert_and_strip_args(bound.args, bound_units.args)
converted_kwargs = convert_and_strip_kwargs(
bound.kwargs, bound_units.kwargs
)

results = func(*converted_args, **converted_kwargs)

return attach_return_units(results, return_units)

return _unit_checking_wrapper

return _expects_decorator
8 changes: 6 additions & 2 deletions pint_xarray/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ def convert_units_dataset(obj, units):


def convert_units(obj, units):
if not isinstance(obj, (DataArray, Dataset)):
if isinstance(obj, Variable):
return convert_units_variable(obj, units)
elif not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot convert object: {obj!r}: unknown type")

if isinstance(obj, DataArray):
Expand Down Expand Up @@ -299,7 +301,9 @@ def strip_units_dataset(obj):


def strip_units(obj):
if not isinstance(obj, (DataArray, Dataset)):
if isinstance(obj, Variable):
return strip_units_variable(obj)
elif not isinstance(obj, (DataArray, Dataset)):
raise ValueError("cannot strip units from {obj!r}: unknown type")

return call_on_dataset(strip_units_dataset, obj, name=temporary_name)
Expand Down
Loading