Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dask: Field.where #464

Merged
merged 3 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 26 additions & 105 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8202,10 +8202,12 @@ def outerproduct(self, a, inplace=False, i=False):
d = _inplace_enabled_define_and_cleanup(self)

# Cast 'a' as a Data object so that it definitely has sensible
# Units
# Units. We don't mind if the units of 'a' are incompatible
# with those of 'self', but if they are then it's nice if the
# units are conformed.
a = self.asdata(a)
try:
a = conform_units(a, d.Units)
a = conform_units(a, d.Units, message="")
except ValueError:
pass

Expand Down Expand Up @@ -9562,8 +9564,11 @@ def where(
**Broadcasting**

The array and the *condition*, *x* and *y* parameters must all
be broadcastable to each other, such that the shape of the
result is identical to the orginal shape of the array.
be broadcastable across the original array, such that the size
of the result is identical to the orginal size of the
davidhassell marked this conversation as resolved.
Show resolved Hide resolved
array. Leading size 1 dimensions of these parameters are
ignored, thereby also ensuring that the shape of the result is
identical to the orginal shape of the array.

If *condition* is a `Query` object then for the purposes of
broadcasting, the condition is considered to be that which is
Expand All @@ -9581,7 +9586,7 @@ def where(

:Parameters:

condition: array-like or `Query`
condition: array_like or `Query`
The condition which determines how to assign values to
the data.

Expand Down Expand Up @@ -9714,7 +9719,7 @@ def where(
>>> d = cf.Data(x)
>>> e = d.where(condition, d, 10 + y)
...
ValueError: where: Broadcasting the 'condition' parameter with shape (3, 4) would change the shape of the data with shape (3, 1)
ValueError: where: 'condition' parameter with shape (3, 4) can not be broadcast across data with shape (3, 1) when the result will have a different shape to the data

>>> d = cf.Data(np.arange(9).reshape(3, 3))
>>> e = d.copy()
Expand All @@ -9732,6 +9737,8 @@ def where(
[ 6. 7. 8. ]]

"""
from .utils import where_broadcastable

d = _inplace_enabled_define_and_cleanup(self)

# Missing values could be affected, so make sure that the mask
Expand All @@ -9744,13 +9751,13 @@ def where(
if getattr(condition, "isquery", False):
# Condition is a cf.Query object: Make sure that the
# condition units are OK, and convert the condition to a
# boolean dask array with the same shape as the data.
# boolean Data instance with the same shape as the data.
condition = condition.copy()
condition = condition.set_condition_units(units)
condition.set_condition_units(units)
condition = condition.evaluate(d)

condition = type(self).asdata(condition)
_where_broadcastable(d, condition, "condition")
condition = where_broadcastable(d, condition, "condition")

# If x or y is self then change it to None. This prevents an
# unnecessary copy; and, at compute time, an unncessary numpy
Expand Down Expand Up @@ -9778,18 +9785,16 @@ def where(
continue

arg = type(self).asdata(arg)
_where_broadcastable(d, arg, name)

if arg.Units:
# Make sure that units are OK.
arg = arg.copy()
try:
arg.Units = units
except ValueError:
raise ValueError(
f"where: {name!r} parameter units {arg.Units!r} "
f"are not equivalent to data units {units!r}"
)
arg = where_broadcastable(d, arg, name)

arg_units = arg.Units
if arg_units:
arg = conform_units(
arg,
units,
message=f"where: {name!r} parameter units {arg_units!r} "
f"are not equivalent to data units {units!r}",
)

xy.append(arg.to_dask_array())

Expand Down Expand Up @@ -11774,90 +11779,6 @@ def _size_of_index(index, size=None):
return len(index)


def _broadcast(a, shape):
"""Broadcast an array to a given shape.

It is assumed that ``len(array.shape) <= len(shape)`` and that the
array is broadcastable to the shape by the normal numpy
boradcasting rules, but neither of these things are checked.

For example, ``d[...] = d._broadcast(e, d.shape)`` gives the same
result as ``d[...] = e``

:Parameters:

a: numpy array-like

shape: `tuple`

:Returns:

`numpy.ndarray`

"""
# Replace with numpy.broadcast_to v1.10 ??/ TODO

a_shape = np.shape(a)
if a_shape == shape:
return a

tile = [(m if n == 1 else 1) for n, m in zip(a_shape[::-1], shape[::-1])]
tile = shape[0 : len(shape) - len(a_shape)] + tuple(tile[::-1])

return np.tile(a, tile)


def _where_broadcastable(data, x, name):
"""Check broadcastability for `where` assignments.

Raises an exception if the result of broadcasting *data* and *x*
together does not have the same shape as *data*.

.. versionadded:: TODODASKVER

.. seealso:: `where`

:Parameters:

data, x: `Data`
The arrays to compare.

name: `str`
A name for *x* that is used in any exception error
message.

:Returns:

`bool`
If *x* is acceptably broadcastable to *data* then `True`
is returned, otherwise a `ValueError` is raised.

"""
ndim_x = x.ndim
if not ndim_x:
return True

ndim_data = data.ndim
if ndim_x > ndim_data:
raise ValueError(
f"where: Broadcasting the {name!r} parameter with {ndim_x} "
f"dimensions would change the shape of the data with "
f"{ndim_data} dimensions"
)

shape_x = x.shape
shape_data = data.shape
for n, m in zip(shape_x[::-1], shape_data[::-1]):
if n != m and n != 1:
raise ValueError(
f"where: Broadcasting the {name!r} parameter with shape "
f"{shape_x} would change the shape of the data with shape "
f"{shape_data}"
)

return True


def _collapse(
func,
d,
Expand Down
96 changes: 89 additions & 7 deletions cf/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,18 +553,18 @@ def scalar_masked_array(dtype=float):
return a


def conform_units(value, units):
def conform_units(value, units, message=None):
"""Conform units.

If *value* has units defined by its `Units` attribute then

* if the value units are equal to *units* then *value* is returned
* If the value units are equal to *units* then *value* is returned
unchanged;

* if the value units are equivalent to *units* then a copy of
* If the value units are equivalent to *units* then a copy of
*value* converted to *units* is returned;

* if the value units are not equivalent to *units* then an
* If the value units are not equivalent to *units* then an
exception is raised.

In all other cases *value* is returned unchanged.
Expand All @@ -579,6 +579,12 @@ def conform_units(value, units):
units: `Units`
The units to conform to.

message: `str`, optional
If the value units are not equivalent to *units* then use
this message when the exception is raised. By default a
message that is independent of the calling context is
used.

**Examples**

>>> cf.data.utils.conform_units(1, cf.Units('m'))
Expand All @@ -600,6 +606,10 @@ def conform_units(value, units):
Traceback (most recent call last):
...
ValueError: Units <Units: km> are incompatible with units <Units: s>
>>> cf.data.utils.conform_units(d, cf.Units('s'), message='My message')
Traceback (most recent call last):
...
ValueError: My message

"""
value_units = getattr(value, "Units", None)
Expand All @@ -611,9 +621,12 @@ def conform_units(value, units):
value = value.copy()
value.Units = units
elif value_units and units:
raise ValueError(
f"Units {value_units!r} are incompatible with units {units!r}"
)
if message is None:
message = (
f"Units {value_units!r} are incompatible with units {units!r}"
)

raise ValueError(message)

return value

Expand Down Expand Up @@ -661,3 +674,72 @@ def YMDhms(d, attr):
d._set_dask(dx)
d.override_units(Units(None), inplace=True)
return d


def where_broadcastable(data, x, name):
"""Check broadcastability for `cf.Data.where` assignments.

Raises an exception unless the *data* and *x* parameters are
broadcastable across each other, such that the size of the result
is identical to the size of *data*. Leading size 1 dimensions of
*x* are ignored, thereby also ensuring that the shape of the
result is identical to the shape of *data*.

.. versionadded:: TODODASKVER

.. seealso:: `cf.Data.where`

:Parameters:

data, x: `Data`
The arrays to compare.

name: `str`
A name for *x* that is used in exception error messages.

:Returns:

`Data`
The input parameter *x*, or a modified copy without
leading size 1 dimensions. If *x* can not be acceptably
broadcast to *data* then a `ValueError` is raised.

"""
ndim_x = x.ndim
if not ndim_x:
return x

error = 0

shape_x = x.shape
shape_data = data.shape

shape_x0 = shape_x
ndim_difference = ndim_x - data.ndim

if ndim_difference > 0:
if shape_x[:ndim_difference] == (1,) * ndim_difference:
# Remove leading ize 1 dimensions
x = x.reshape(shape_x[ndim_difference:])
shape_x = x.shape
else:
error += 1

for n, m in zip(shape_x[::-1], shape_data[::-1]):
if n != m and m > 1 and n > 1:
raise ValueError(
f"where: {name!r} parameter with shape {shape_x0} can not "
f"be broadcast across data with shape {shape_data}"
)

if m == 1 and n > 1:
error += 1

if error:
raise ValueError(
f"where: {name!r} parameter with shape {shape_x0} can not "
f"be broadcast across data with shape {shape_data} when the "
"result will have a different shape to the data"
)

return x
Loading