Skip to content

Commit 16c2895

Browse files
committed
Add broadcast function to the API
This is a renaming and update of the existing `xray.broadcast_arrays` function, which now works properly in the light of GH648.

Examples
--------

Broadcast two data arrays against one another to fill out their dimensions:

>>> a = xray.DataArray([1, 2, 3], dims='x')
>>> b = xray.DataArray([5, 6], dims='y')
>>> a
<xray.DataArray (x: 3)>
array([1, 2, 3])
Coordinates:
  * x        (x) int64 0 1 2
>>> b
<xray.DataArray (y: 2)>
array([5, 6])
Coordinates:
  * y        (y) int64 0 1
>>> a2, b2 = xray.broadcast(a, b)
>>> a2
<xray.DataArray (x: 3, y: 2)>
array([[1, 1],
       [2, 2],
       [3, 3]])
Coordinates:
  * x        (x) int64 0 1 2
  * y        (y) int64 0 1
>>> b2
<xray.DataArray (x: 3, y: 2)>
array([[5, 6],
       [5, 6],
       [5, 6]])
Coordinates:
  * y        (y) int64 0 1
  * x        (x) int64 0 1 2

Fill out the dimensions of all data variables in a dataset:

>>> ds = xray.Dataset({'a': a, 'b': b})
>>> ds2, = xray.broadcast(ds)  # use tuple unpacking to extract one dataset
>>> ds2
<xray.Dataset>
Dimensions:  (x: 3, y: 2)
Coordinates:
  * x        (x) int64 0 1 2
  * y        (y) int64 0 1
Data variables:
    a        (x, y) int64 1 1 2 2 3 3
    b        (x, y) int64 5 6 5 6 5 6
1 parent b7b8fae commit 16c2895

File tree

6 files changed

+144
-25
lines changed

6 files changed

+144
-25
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Top-level functions
1515
:toctree: generated/
1616

1717
align
18+
broadcast
1819
concat
1920
set_options
2021

doc/whats-new.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ Enhancements
9191
9292
Notice that ``shift`` moves data independently of coordinates, but ``roll``
9393
moves both data and coordinates.
94+
- New function :py:func:`~xray.broadcast` for explicitly broadcasting
95+
``DataArray`` and ``Dataset`` objects against each other. For example:
96+
97+
.. ipython:: python
98+
99+
a = xray.DataArray([1, 2, 3], dims='x')
100+
b = xray.DataArray([5, 6], dims='y')
101+
a
102+
b
103+
a2, b2 = xray.broadcast(a, b)
104+
a2
105+
b2
94106
95107
Bug fixes
96108
~~~~~~~~~

xray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .core.alignment import align, broadcast_arrays
1+
from .core.alignment import align, broadcast, broadcast_arrays
22
from .core.combine import concat, auto_combine
33
from .core.variable import Variable, Coordinate
44
from .core.dataset import Dataset

xray/core/alignment.py

Lines changed: 99 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -244,46 +244,128 @@ def var_indexers(var, indexers):
244244
return reindexed
245245

246246

247-
def broadcast_arrays(*args):
248-
"""Explicitly broadcast any number of DataArrays against one another.
247+
def broadcast(*args):
248+
"""Explicitly broadcast any number of DataArray or Dataset objects against
249+
one another.
249250
250251
xray objects automatically broadcast against each other in arithmetic
251252
operations, so this function should not be necessary for normal use.
252253
253254
Parameters
254255
----------
255-
*args: DataArray
256+
*args: DataArray or Dataset objects
256257
Arrays to broadcast against each other.
257258
258259
Returns
259260
-------
260-
broadcast: tuple of DataArray
261+
broadcast: tuple of xray objects
261262
The same data as the input arrays, but with additional dimensions
262-
inserted so that all arrays have the same dimensions and shape.
263+
inserted so that all data arrays have the same dimensions and shape.
263264
264265
Raises
265266
------
266267
ValueError
267-
If indexes on the different arrays are not aligned.
268+
If indexes on the different objects are not aligned.
269+
270+
Examples
271+
--------
272+
273+
Broadcast two data arrays against one another to fill out their dimensions:
274+
275+
>>> a = xray.DataArray([1, 2, 3], dims='x')
276+
>>> b = xray.DataArray([5, 6], dims='y')
277+
>>> a
278+
<xray.DataArray (x: 3)>
279+
array([1, 2, 3])
280+
Coordinates:
281+
* x (x) int64 0 1 2
282+
>>> b
283+
<xray.DataArray (y: 2)>
284+
array([5, 6])
285+
Coordinates:
286+
* y (y) int64 0 1
287+
>>> a2, b2 = xray.broadcast(a, b)
288+
>>> a2
289+
<xray.DataArray (x: 3, y: 2)>
290+
array([[1, 1],
291+
[2, 2],
292+
[3, 3]])
293+
Coordinates:
294+
* x (x) int64 0 1 2
295+
* y (y) int64 0 1
296+
>>> b2
297+
<xray.DataArray (x: 3, y: 2)>
298+
array([[5, 6],
299+
[5, 6],
300+
[5, 6]])
301+
Coordinates:
302+
* y (y) int64 0 1
303+
* x (x) int64 0 1 2
304+
305+
Fill out the dimensions of all data variables in a dataset:
306+
307+
>>> ds = xray.Dataset({'a': a, 'b': b})
308+
>>> ds2, = xray.broadcast(ds) # use tuple unpacking to extract one dataset
309+
>>> ds2
310+
<xray.Dataset>
311+
Dimensions: (x: 3, y: 2)
312+
Coordinates:
313+
* x (x) int64 0 1 2
314+
* y (y) int64 0 1
315+
Data variables:
316+
a (x, y) int64 1 1 2 2 3 3
317+
b (x, y) int64 5 6 5 6 5 6
268318
"""
269-
# TODO: fixme for coordinate arrays
270-
271319
from .dataarray import DataArray
320+
from .dataset import Dataset
272321

273322
all_indexes = _get_all_indexes(args)
274323
for k, v in all_indexes.items():
275324
if not all(v[0].equals(vi) for vi in v[1:]):
276325
raise ValueError('cannot broadcast arrays: the %s index is not '
277326
'aligned (use xray.align first)' % k)
278327

279-
vars = broadcast_variables(*[a.variable for a in args])
280-
indexes = dict((k, all_indexes[k][0]) for k in vars[0].dims)
328+
common_coords = OrderedDict()
329+
dims_map = OrderedDict()
330+
for arg in args:
331+
for dim in arg.dims:
332+
if dim not in common_coords:
333+
common_coords[dim] = arg.coords[dim].variable
334+
dims_map[dim] = common_coords[dim].size
335+
336+
def _broadcast_array(array):
337+
data = array.variable.expand_dims(dims_map)
338+
coords = OrderedDict(array.coords)
339+
coords.update(common_coords)
340+
dims = tuple(common_coords)
341+
name = array.name
342+
attrs = array.attrs
343+
encoding = array.encoding
344+
return DataArray(data, coords, dims, name, attrs, encoding)
345+
346+
def _broadcast_dataset(ds):
347+
data_vars = OrderedDict()
348+
for k in ds.data_vars:
349+
data_vars[k] = ds.variables[k].expand_dims(dims_map)
350+
351+
coords = OrderedDict(ds.coords)
352+
coords.update(common_coords)
353+
354+
return Dataset(data_vars, coords, ds.attrs)
355+
356+
result = []
357+
for arg in args:
358+
if isinstance(arg, DataArray):
359+
result.append(_broadcast_array(arg))
360+
elif isinstance(arg, Dataset):
361+
result.append(_broadcast_dataset(arg))
362+
else:
363+
raise ValueError('all input must be Dataset or DataArray objects')
281364

282-
arrays = []
283-
for a, v in zip(args, vars):
284-
arr = DataArray(v.values, indexes, v.dims, a.name, a.attrs, a.encoding)
285-
for k, v in a.coords.items():
286-
arr.coords[k] = v
287-
arrays.append(arr)
365+
return tuple(result)
288366

289-
return tuple(arrays)
367+
368+
def broadcast_arrays(*args):
369+
warnings.warn('xray.broadcast_arrays is deprecated: use xray.broadcast '
370+
'instead', DeprecationWarning, stacklevel=2)
371+
return broadcast(*args)

xray/test/test_dataarray.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import deepcopy
55
from textwrap import dedent
66

7-
from xray import (align, broadcast_arrays, Dataset, DataArray,
7+
from xray import (align, broadcast, Dataset, DataArray,
88
Coordinate, Variable)
99
from xray.core.pycompat import iteritems, OrderedDict
1010
from . import TestCase, ReturnItem, source_ndarray, unittest, requires_dask
@@ -1267,7 +1267,7 @@ def test_align_dtype(self):
12671267
def test_broadcast_arrays(self):
12681268
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
12691269
y = DataArray([1, 2], coords=[('b', [3, 4])], name='y')
1270-
x2, y2 = broadcast_arrays(x, y)
1270+
x2, y2 = broadcast(x, y)
12711271
expected_coords = [('a', [-1, -2]), ('b', [3, 4])]
12721272
expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name='x')
12731273
expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name='y')
@@ -1276,15 +1276,15 @@ def test_broadcast_arrays(self):
12761276

12771277
x = DataArray(np.random.randn(2, 3), dims=['a', 'b'])
12781278
y = DataArray(np.random.randn(3, 2), dims=['b', 'a'])
1279-
x2, y2 = broadcast_arrays(x, y)
1279+
x2, y2 = broadcast(x, y)
12801280
expected_x2 = x
12811281
expected_y2 = y.T
12821282
self.assertDataArrayIdentical(expected_x2, x2)
12831283
self.assertDataArrayIdentical(expected_y2, y2)
12841284

1285+
z = DataArray([1, 2], coords=[('a', [-10, 20])])
12851286
with self.assertRaisesRegexp(ValueError, 'cannot broadcast'):
1286-
z = DataArray([1, 2], coords=[('a', [-10, 20])])
1287-
broadcast_arrays(x, z)
1287+
broadcast(x, z)
12881288

12891289
def test_to_pandas(self):
12901290
# 0d

xray/test/test_dataset.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import numpy as np
1313
import pandas as pd
1414

15-
from xray import (align, concat, conventions, backends, Dataset, DataArray,
16-
Variable, Coordinate, auto_combine, open_dataset,
15+
from xray import (align, broadcast, concat, conventions, backends, Dataset,
16+
DataArray, Variable, Coordinate, auto_combine, open_dataset,
1717
set_options)
1818
from xray.core import indexing, utils
1919
from xray.core.pycompat import iteritems, OrderedDict
@@ -937,6 +937,30 @@ def test_align(self):
937937
with self.assertRaises(TypeError):
938938
align(left, right, foo='bar')
939939

940+
def test_broadcast(self):
941+
ds = Dataset({'foo': 0, 'bar': ('x', [1]), 'baz': ('y', [2, 3])},
942+
{'c': ('x', [4])})
943+
expected = Dataset({'foo': (('x', 'y'), [[0, 0]]),
944+
'bar': (('x', 'y'), [[1, 1]]),
945+
'baz': (('x', 'y'), [[2, 3]])},
946+
{'c': ('x', [4])})
947+
actual, = broadcast(ds)
948+
self.assertDatasetIdentical(expected, actual)
949+
950+
ds_x = Dataset({'foo': ('x', [1])})
951+
ds_y = Dataset({'bar': ('y', [2, 3])})
952+
expected_x = Dataset({'foo': (('x', 'y'), [[1, 1]])})
953+
expected_y = Dataset({'bar': (('x', 'y'), [[2, 3]])})
954+
actual_x, actual_y = broadcast(ds_x, ds_y)
955+
self.assertDatasetIdentical(expected_x, actual_x)
956+
self.assertDatasetIdentical(expected_y, actual_y)
957+
958+
array_y = ds_y['bar']
959+
expected_y = expected_y['bar']
960+
actual_x, actual_y = broadcast(ds_x, array_y)
961+
self.assertDatasetIdentical(expected_x, actual_x)
962+
self.assertDataArrayIdentical(expected_y, actual_y)
963+
940964
def test_variable_indexing(self):
941965
data = create_test_data()
942966
v = data['var1']

0 commit comments

Comments
 (0)