Skip to content

Commit 52ee95f

Browse files
committed
Add broadcast function to the API
This is a renaming and update of the existing `xray.broadcast_arrays` function, which now works properly in the light of GH648. Examples -------- Broadcast two data arrays against one another to fill out their dimensions: >>> a = xray.DataArray([1, 2, 3], dims='x') >>> b = xray.DataArray([5, 6], dims='y') >>> a <xray.DataArray (x: 3)> array([1, 2, 3]) Coordinates: * x (x) int64 0 1 2 >>> b <xray.DataArray (y: 2)> array([5, 6]) Coordinates: * y (y) int64 0 1 >>> a2, b2 = xray.broadcast(a, b) >>> a2 <xray.DataArray (x: 3, y: 2)> array([[1, 1], [2, 2], [3, 3]]) Coordinates: * x (x) int64 0 1 2 * y (y) int64 0 1 >>> b2 <xray.DataArray (x: 3, y: 2)> array([[5, 6], [5, 6], [5, 6]]) Coordinates: * y (y) int64 0 1 * x (x) int64 0 1 2 Fill out the dimensions of all data variables in a dataset: >>> ds = xray.Dataset({'a': a, 'b': b}) >>> ds2, = xray.broadcast(ds) # use tuple unpacking to extract one dataset >>> ds2 <xray.Dataset> Dimensions: (x: 3, y: 2) Coordinates: * x (x) int64 0 1 2 * y (y) int64 0 1 Data variables: a (x, y) int64 1 1 2 2 3 3 b (x, y) int64 5 6 5 6 5 6
1 parent e1097cf commit 52ee95f

File tree

6 files changed

+155
-26
lines changed

6 files changed

+155
-26
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Top-level functions
1515
:toctree: generated/
1616

1717
align
18+
broadcast
1819
concat
1920
set_options
2021

doc/whats-new.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Bug fixes
6767
- Fixes for several issues found on ``DataArray`` objects with the same name
6868
as one of their coordinates (see :ref:`v0.7.0.breaking` for more details).
6969

70-
- ``DataArray.to_masked_array`` always returns masked array with mask being an array
70+
- ``DataArray.to_masked_array`` always returns masked array with mask being an array
7171
(not a scalar value) (:issue:`684`)
7272

7373
v0.6.2 (unreleased)
@@ -96,6 +96,18 @@ Enhancements
9696
moves both data and coordinates.
9797
- Assigning a ``pandas`` object to a ``Dataset`` directly is now permitted. Its
9898
index names correspond to the ``dims`` of the ``Dataset``, and its data is aligned
99+
- New function :py:func:`~xray.broadcast` for explicitly broadcasting
100+
``DataArray`` and ``Dataset`` objects against each other. For example:
101+
102+
.. ipython:: python
103+
104+
a = xray.DataArray([1, 2, 3], dims='x')
105+
b = xray.DataArray([5, 6], dims='y')
106+
a
107+
b
108+
a2, b2 = xray.broadcast(a, b)
109+
a2
110+
b2
99111
100112
Bug fixes
101113
~~~~~~~~~

xray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .core.alignment import align, broadcast_arrays
1+
from .core.alignment import align, broadcast, broadcast_arrays
22
from .core.combine import concat, auto_combine
33
from .core.variable import Variable, Coordinate
44
from .core.dataset import Dataset

xray/core/alignment.py

Lines changed: 97 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -253,46 +253,126 @@ def var_indexers(var, indexers):
253253
return reindexed
254254

255255

256-
def broadcast_arrays(*args):
257-
"""Explicitly broadcast any number of DataArrays against one another.
256+
def broadcast(*args):
257+
"""Explicitly broadcast any number of DataArray or Dataset objects against
258+
one another.
258259
259260
xray objects automatically broadcast against each other in arithmetic
260261
operations, so this function should not be necessary for normal use.
261262
262263
Parameters
263264
----------
264-
*args: DataArray
265+
*args: DataArray or Dataset objects
265266
Arrays to broadcast against each other.
266267
267268
Returns
268269
-------
269-
broadcast: tuple of DataArray
270+
broadcast: tuple of xray objects
270271
The same data as the input arrays, but with additional dimensions
271-
inserted so that all arrays have the same dimensions and shape.
272+
inserted so that all data arrays have the same dimensions and shape.
272273
273274
Raises
274275
------
275276
ValueError
276-
If indexes on the different arrays are not aligned.
277+
If indexes on the different objects are not aligned.
278+
279+
Examples
280+
--------
281+
282+
Broadcast two data arrays against one another to fill out their dimensions:
283+
284+
>>> a = xray.DataArray([1, 2, 3], dims='x')
285+
>>> b = xray.DataArray([5, 6], dims='y')
286+
>>> a
287+
<xray.DataArray (x: 3)>
288+
array([1, 2, 3])
289+
Coordinates:
290+
* x (x) int64 0 1 2
291+
>>> b
292+
<xray.DataArray (y: 2)>
293+
array([5, 6])
294+
Coordinates:
295+
* y (y) int64 0 1
296+
>>> a2, b2 = xray.broadcast(a, b)
297+
>>> a2
298+
<xray.DataArray (x: 3, y: 2)>
299+
array([[1, 1],
300+
[2, 2],
301+
[3, 3]])
302+
Coordinates:
303+
* x (x) int64 0 1 2
304+
* y (y) int64 0 1
305+
>>> b2
306+
<xray.DataArray (x: 3, y: 2)>
307+
array([[5, 6],
308+
[5, 6],
309+
[5, 6]])
310+
Coordinates:
311+
* y (y) int64 0 1
312+
* x (x) int64 0 1 2
313+
314+
Fill out the dimensions of all data variables in a dataset:
315+
316+
>>> ds = xray.Dataset({'a': a, 'b': b})
317+
>>> ds2, = xray.broadcast(ds) # use tuple unpacking to extract one dataset
318+
>>> ds2
319+
<xray.Dataset>
320+
Dimensions: (x: 3, y: 2)
321+
Coordinates:
322+
* x (x) int64 0 1 2
323+
* y (y) int64 0 1
324+
Data variables:
325+
a (x, y) int64 1 1 2 2 3 3
326+
b (x, y) int64 5 6 5 6 5 6
277327
"""
278-
# TODO: fixme for coordinate arrays
279-
280328
from .dataarray import DataArray
329+
from .dataset import Dataset
281330

282331
all_indexes = _get_all_indexes(args)
283332
for k, v in all_indexes.items():
284333
if not all(v[0].equals(vi) for vi in v[1:]):
285334
raise ValueError('cannot broadcast arrays: the %s index is not '
286335
'aligned (use xray.align first)' % k)
287336

288-
vars = broadcast_variables(*[a.variable for a in args])
289-
indexes = dict((k, all_indexes[k][0]) for k in vars[0].dims)
337+
common_coords = OrderedDict()
338+
dims_map = OrderedDict()
339+
for arg in args:
340+
for dim in arg.dims:
341+
if dim not in common_coords:
342+
common_coords[dim] = arg.coords[dim].variable
343+
dims_map[dim] = common_coords[dim].size
344+
345+
def _broadcast_array(array):
346+
data = array.variable.expand_dims(dims_map)
347+
coords = OrderedDict(array.coords)
348+
coords.update(common_coords)
349+
dims = tuple(common_coords)
350+
return DataArray(data, coords, dims, name=array.name,
351+
attrs=array.attrs, encoding=array.encoding)
352+
353+
def _broadcast_dataset(ds):
354+
data_vars = OrderedDict()
355+
for k in ds.data_vars:
356+
data_vars[k] = ds.variables[k].expand_dims(dims_map)
357+
358+
coords = OrderedDict(ds.coords)
359+
coords.update(common_coords)
360+
361+
return Dataset(data_vars, coords, ds.attrs)
362+
363+
result = []
364+
for arg in args:
365+
if isinstance(arg, DataArray):
366+
result.append(_broadcast_array(arg))
367+
elif isinstance(arg, Dataset):
368+
result.append(_broadcast_dataset(arg))
369+
else:
370+
raise ValueError('all input must be Dataset or DataArray objects')
290371

291-
arrays = []
292-
for a, v in zip(args, vars):
293-
arr = DataArray(v.values, indexes, v.dims, a.name, a.attrs, a.encoding)
294-
for k, v in a.coords.items():
295-
arr.coords[k] = v
296-
arrays.append(arr)
372+
return tuple(result)
297373

298-
return tuple(arrays)
374+
375+
def broadcast_arrays(*args):
376+
warnings.warn('xray.broadcast_arrays is deprecated: use xray.broadcast '
377+
'instead', DeprecationWarning, stacklevel=2)
378+
return broadcast(*args)

xray/test/test_dataarray.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import deepcopy
55
from textwrap import dedent
66

7-
from xray import (align, broadcast_arrays, Dataset, DataArray,
7+
from xray import (align, broadcast, Dataset, DataArray,
88
Coordinate, Variable)
99
from xray.core.pycompat import iteritems, OrderedDict
1010
from . import TestCase, ReturnItem, source_ndarray, unittest, requires_dask
@@ -1267,7 +1267,7 @@ def test_align_dtype(self):
12671267
def test_broadcast_arrays(self):
12681268
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
12691269
y = DataArray([1, 2], coords=[('b', [3, 4])], name='y')
1270-
x2, y2 = broadcast_arrays(x, y)
1270+
x2, y2 = broadcast(x, y)
12711271
expected_coords = [('a', [-1, -2]), ('b', [3, 4])]
12721272
expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name='x')
12731273
expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name='y')
@@ -1276,15 +1276,27 @@ def test_broadcast_arrays(self):
12761276

12771277
x = DataArray(np.random.randn(2, 3), dims=['a', 'b'])
12781278
y = DataArray(np.random.randn(3, 2), dims=['b', 'a'])
1279-
x2, y2 = broadcast_arrays(x, y)
1279+
x2, y2 = broadcast(x, y)
12801280
expected_x2 = x
12811281
expected_y2 = y.T
12821282
self.assertDataArrayIdentical(expected_x2, x2)
12831283
self.assertDataArrayIdentical(expected_y2, y2)
12841284

1285+
z = DataArray([1, 2], coords=[('a', [-10, 20])])
12851286
with self.assertRaisesRegexp(ValueError, 'cannot broadcast'):
1286-
z = DataArray([1, 2], coords=[('a', [-10, 20])])
1287-
broadcast_arrays(x, z)
1287+
broadcast(x, z)
1288+
1289+
def test_broadcast_coordinates(self):
1290+
# regression test for GH649
1291+
ds = Dataset({'a': (['x', 'y'], np.ones((5, 6)))})
1292+
x_bc, y_bc, a_bc = broadcast(ds.x, ds.y, ds.a)
1293+
self.assertDataArrayIdentical(ds.a, a_bc)
1294+
1295+
X, Y = np.meshgrid(np.arange(5), np.arange(6), indexing='ij')
1296+
exp_x = DataArray(X, dims=['x', 'y'], name='x')
1297+
exp_y = DataArray(Y, dims=['x', 'y'], name='y')
1298+
self.assertDataArrayIdentical(exp_x, x_bc)
1299+
self.assertDataArrayIdentical(exp_y, y_bc)
12881300

12891301
def test_to_pandas(self):
12901302
# 0d

xray/test/test_dataset.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import numpy as np
1313
import pandas as pd
1414

15-
from xray import (align, concat, conventions, backends, Dataset, DataArray,
16-
Variable, Coordinate, auto_combine, open_dataset,
15+
from xray import (align, broadcast, concat, conventions, backends, Dataset,
16+
DataArray, Variable, Coordinate, auto_combine, open_dataset,
1717
set_options)
1818
from xray.core import indexing, utils
1919
from xray.core.pycompat import iteritems, OrderedDict
@@ -953,6 +953,30 @@ def test_align(self):
953953
with self.assertRaises(TypeError):
954954
align(left, right, foo='bar')
955955

956+
def test_broadcast(self):
957+
ds = Dataset({'foo': 0, 'bar': ('x', [1]), 'baz': ('y', [2, 3])},
958+
{'c': ('x', [4])})
959+
expected = Dataset({'foo': (('x', 'y'), [[0, 0]]),
960+
'bar': (('x', 'y'), [[1, 1]]),
961+
'baz': (('x', 'y'), [[2, 3]])},
962+
{'c': ('x', [4])})
963+
actual, = broadcast(ds)
964+
self.assertDatasetIdentical(expected, actual)
965+
966+
ds_x = Dataset({'foo': ('x', [1])})
967+
ds_y = Dataset({'bar': ('y', [2, 3])})
968+
expected_x = Dataset({'foo': (('x', 'y'), [[1, 1]])})
969+
expected_y = Dataset({'bar': (('x', 'y'), [[2, 3]])})
970+
actual_x, actual_y = broadcast(ds_x, ds_y)
971+
self.assertDatasetIdentical(expected_x, actual_x)
972+
self.assertDatasetIdentical(expected_y, actual_y)
973+
974+
array_y = ds_y['bar']
975+
expected_y = expected_y['bar']
976+
actual_x, actual_y = broadcast(ds_x, array_y)
977+
self.assertDatasetIdentical(expected_x, actual_x)
978+
self.assertDataArrayIdentical(expected_y, actual_y)
979+
956980
def test_variable_indexing(self):
957981
data = create_test_data()
958982
v = data['var1']

0 commit comments

Comments
 (0)