Load nonindex coords ahead of concat() (#1551)
* Load non-index coords to memory ahead of concat

* Update unit test after #1522

* Minimise loads on concat. Extend new concat logic to data_vars.

* Trivial tweaks

* Added unit tests

Fix loads when vars are found different halfway through

* Add xfail for #1586

* Revert "Add xfail for #1586"

This reverts commit f99313c.
crusaderky authored and Joe Hamman committed Oct 9, 2017
1 parent 772f7e0 commit 14b5f1c
Showing 3 changed files with 153 additions and 43 deletions.
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -189,6 +189,13 @@ Bug fixes
``rtol`` arguments when called on ``DataArray`` objects.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- :py:func:`~xarray.concat` was computing variables that aren't in memory
(e.g. dask-based) multiple times; :py:func:`~xarray.open_mfdataset`
was loading them multiple times from disk. Now, both functions will instead
load them at most once and, if they do, store them in memory in the
concatenated array/dataset (:issue:`1521`).
By `Guido Imperiale <https://github.com/crusaderky>`_.

- xarray ``quantile`` methods now properly raise a ``TypeError`` when applied to
objects with data stored as ``dask`` arrays (:issue:`1529`).
By `Joe Hamman <https://github.com/jhamman>`_.
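
As an illustration of the :py:func:`~xarray.concat` entry above, the sketch below shows the new behaviour for dask-backed variables. It is not part of the commit: the dataset contents and variable names are invented for the example, and it assumes numpy, dask.array and xarray are importable.

import numpy as np
import dask.array as da
import xarray as xr

# Two datasets whose data variable 'd' is dask-backed (not in memory).
ds1 = xr.Dataset({'d': ('x', da.ones(3, chunks=3))})
ds2 = xr.Dataset({'d': ('x', da.zeros(3, chunks=3))})

# data_vars='different' forces concat to compare 'd' across datasets.
# Each lazy variable is now computed at most once during that comparison,
# and the loaded values are reused in the result instead of being
# recomputed later.
out = xr.concat([ds1, ds2], dim='t', data_vars='different')
assert isinstance(out['d'].data, np.ndarray)   # stored in memory in the result
assert isinstance(ds1['d'].data, da.Array)     # inputs are left untouched
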
113 changes: 72 additions & 41 deletions xarray/core/combine.py
@@ -148,68 +148,85 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
Determine which dataset variables need to be concatenated in the result,
and which can simply be taken from the first dataset.
"""
def process_subset_opt(opt, subset):
if subset == 'coords':
subset_long_name = 'coordinates'
else:
subset_long_name = 'data variables'
# Return values
concat_over = set()
equals = {}

if dim in datasets[0]:
concat_over.add(dim)
for ds in datasets:
concat_over.update(k for k, v in ds.variables.items()
if dim in v.dims)

def process_subset_opt(opt, subset):
if isinstance(opt, basestring):
if opt == 'different':
def differs(vname):
# simple helper function which compares a variable
# across all datasets and indicates whether that
# variable differs or not.
v = datasets[0].variables[vname]
return any(not ds.variables[vname].equals(v)
for ds in datasets[1:])
# all nonindexes that are not the same in each dataset
concat_new = set(k for k in getattr(datasets[0], subset)
if k not in concat_over and differs(k))
for k in getattr(datasets[0], subset):
if k not in concat_over:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk while
# keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not v_lhs.equals(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True

elif opt == 'all':
concat_new = (set(getattr(datasets[0], subset)) -
set(datasets[0].dims))
concat_over.update(set(getattr(datasets[0], subset)) -
set(datasets[0].dims))
elif opt == 'minimal':
concat_new = set()
pass
else:
raise ValueError("unexpected value for concat_%s: %s"
% (subset, opt))
raise ValueError("unexpected value for %s: %s" % (subset, opt))
else:
invalid_vars = [k for k in opt
if k not in getattr(datasets[0], subset)]
if invalid_vars:
raise ValueError('some variables in %s are not '
'%s on the first dataset: %s'
% (subset, subset_long_name, invalid_vars))
concat_new = set(opt)
return concat_new
if subset == 'coords':
raise ValueError(
'some variables in coords are not coordinates on '
'the first dataset: %s' % invalid_vars)
else:
raise ValueError(
'some variables in data_vars are not data variables on '
'the first dataset: %s' % invalid_vars)
concat_over.update(opt)

concat_over = set()
for ds in datasets:
concat_over.update(k for k, v in ds.variables.items()
if dim in v.dims)
concat_over.update(process_subset_opt(data_vars, 'data_vars'))
concat_over.update(process_subset_opt(coords, 'coords'))
if dim in datasets[0]:
concat_over.add(dim)
return concat_over
process_subset_opt(data_vars, 'data_vars')
process_subset_opt(coords, 'coords')
return concat_over, equals


def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
from .dataset import Dataset, as_dataset
from .dataset import Dataset

if compat not in ['equals', 'identical']:
raise ValueError("compat=%r invalid: must be 'equals' "
"or 'identical'" % compat)

dim, coord = _calc_concat_dim_coord(dim)
datasets = [as_dataset(ds) for ds in datasets]
# Make sure we're working on a copy (we'll be loading variables)
datasets = [ds.copy() for ds in datasets]
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])

concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)

def insert_result_variable(k, v):
assert isinstance(v, Variable)
@@ -239,11 +256,25 @@ def insert_result_variable(k, v):
elif (k in result_coord_names) != (k in ds.coords):
raise ValueError('%r is a coordinate in some datasets but not '
'others' % k)
elif (k in result_vars and k != dim and
not getattr(v, compat)(result_vars[k])):
verb = 'equal' if compat == 'equals' else compat
raise ValueError(
'variable %r not %s across datasets' % (k, verb))
elif k in result_vars and k != dim:
# Don't use Variable.identical as it internally invokes
# Variable.equals, and we may already know the answer
if compat == 'identical' and not utils.dict_equiv(
v.attrs, result_vars[k].attrs):
raise ValueError(
'variable %s not identical across datasets' % k)

# Proceed with equals()
try:
# May be populated when using the "different" method
is_equal = equals[k]
except KeyError:
result_vars[k].load()
is_equal = v.equals(result_vars[k])
if not is_equal:
raise ValueError(
'variable %s not equal across datasets' % k)


# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
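
The heart of the new 'different' branch in _calc_concat_over above is a compare-against-the-first pattern: load the first dataset's variable once, compute each remaining variable at most once, cache whatever was computed so it is never recomputed, and stop as soon as a mismatch is found. A standalone sketch of that idea follows; the helper name and the plain callables are invented for illustration, whereas the real code operates on xarray Variables.

def differs_with_minimal_loads(lazy_values, compute, equal):
    # `compute` materialises one lazy value; `equal` compares two
    # materialised values.  Returns (differs, computed), where `computed`
    # caches everything that was materialised so callers can reuse it
    # instead of computing it again.
    first = compute(lazy_values[0])        # loaded exactly once
    computed = [first]
    for lazy in lazy_values[1:]:
        value = compute(lazy)              # each value computed at most once
        computed.append(value)
        if not equal(first, value):
            return True, computed          # mismatch found: stop early
    return False, computed
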
76 changes: 74 additions & 2 deletions xarray/tests/test_dask.py
@@ -250,6 +250,77 @@ def test_lazy_array(self):
actual = xr.concat([v[:2], v[2:]], 'x')
self.assertLazyAndAllClose(u, actual)

def test_concat_loads_variables(self):
# Test that concat() computes not-in-memory variables at most once
# and loads them in the output, while leaving the input unaltered.
d1 = build_dask_array('d1')
c1 = build_dask_array('c1')
d2 = build_dask_array('d2')
c2 = build_dask_array('c2')
d3 = build_dask_array('d3')
c3 = build_dask_array('c3')
# Note: c is a non-index coord.
# Index coords are loaded by IndexVariable.__init__.
ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)})
ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)})
ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)})

assert kernel_call_count == 0
out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
coords='different')
# each kernel is computed exactly once
assert kernel_call_count == 6
# variables are loaded in the output
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all')
# no extra kernel calls
assert kernel_call_count == 6
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c'])
# no extra kernel calls
assert kernel_call_count == 6
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[])
# variables are loaded once as we are validating that they're identical
assert kernel_call_count == 12
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
coords='different', compat='identical')
# compat=identical doesn't do any more kernel calls than compat=equals
assert kernel_call_count == 18
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

# When the 'different' comparison finds a mismatch halfway through,
# stop computing the remaining variables, as doing so has no benefit
ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])})
out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different',
coords='different')
# the variables of ds1 and ds2 were computed, but those of ds3 were not
assert kernel_call_count == 22
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)
# the data of ds1 and ds2 was loaded into numpy and then
# concatenated to the data of ds3. Thus, only ds3 is computed now.
out.compute()
assert kernel_call_count == 24

# Finally, test that the originals are unaltered
assert ds1['d'].data is d1
assert ds1['c'].data is c1
assert ds2['d'].data is d2
assert ds2['c'].data is c2
assert ds3['d'].data is d3
assert ds3['c'].data is c3

def test_groupby(self):
if LooseVersion(dask.__version__) == LooseVersion('0.15.3'):
pytest.xfail('upstream bug in dask: '
@@ -517,10 +588,11 @@ def test_dask_kwargs_dataset(method):
kernel_call_count = 0


def kernel():
def kernel(name):
"""Dask kernel to test pickling/unpickling and __repr__.
Must be global to make it pickleable.
"""
print("kernel(%s)" % name)
global kernel_call_count
kernel_call_count += 1
return np.ones(1, dtype=np.int64)
@@ -530,5 +602,5 @@ def build_dask_array(name):
global kernel_call_count
kernel_call_count = 0
return dask.array.Array(
dask={(name, 0): (kernel, )}, name=name,
dask={(name, 0): (kernel, name)}, name=name,
chunks=((1,),), dtype=np.int64)
