Load nonindex coords ahead of concat() #1551

Merged: 13 commits (Oct 9, 2017)

Changes from 11 commits
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -189,6 +189,13 @@ Bug fixes
``rtol`` arguments when called on ``DataArray`` objects.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- :py:func:`~xarray.concat` was computing variables that aren't in memory
(e.g. dask-based) multiple times; :py:func:`~xarray.open_mfdataset`
was loading them multiple times from disk. Now, both functions will instead
load them at most once and, if they do, store them in memory in the
concatenated array/dataset (:issue:`1521`).
By `Guido Imperiale <https://github.com/crusaderky>`_.

- xarray ``quantile`` methods now properly raise a ``TypeError`` when applied to
objects with data stored as ``dask`` arrays (:issue:`1529`).
By `Joe Hamman <https://github.com/jhamman>`_.
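A minimal sketch of the behaviour described in the concat/open_mfdataset entry above. The dataset names, shapes and chunking are made up for the example, and the assertions mirror the new test added to test_dask.py in this PR:

    import dask.array as da
    import numpy as np
    import xarray as xr

    # Two datasets whose data variable and non-index coordinate are lazy dask arrays.
    ds1 = xr.Dataset({'d': ('x', da.zeros(3, chunks=3))},
                     coords={'c': ('x', da.zeros(3, chunks=3))})
    ds2 = xr.Dataset({'d': ('x', da.zeros(3, chunks=3))},
                     coords={'c': ('x', da.zeros(3, chunks=3))})

    out = xr.concat([ds1, ds2], dim='n', data_vars='different', coords='different')

    # Each lazy variable is now computed at most once to decide whether it
    # differs across datasets, and the loaded values are stored in the result
    # instead of being recomputed later.
    assert isinstance(out['c'].data, np.ndarray)
    # The inputs are left untouched and stay lazy.
    assert isinstance(ds1['c'].data, da.Array)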
113 changes: 72 additions & 41 deletions xarray/core/combine.py
@@ -148,68 +148,85 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
Determine which dataset variables need to be concatenated in the result,
and which can simply be taken from the first dataset.
"""
def process_subset_opt(opt, subset):
if subset == 'coords':
subset_long_name = 'coordinates'
else:
subset_long_name = 'data variables'
# Return values
concat_over = set()
equals = {}

if dim in datasets[0]:
concat_over.add(dim)
for ds in datasets:
concat_over.update(k for k, v in ds.variables.items()
if dim in v.dims)

def process_subset_opt(opt, subset):
if isinstance(opt, basestring):
if opt == 'different':
def differs(vname):
# simple helper function which compares a variable
# across all datasets and indicates whether that
# variable differs or not.
v = datasets[0].variables[vname]
return any(not ds.variables[vname].equals(v)
for ds in datasets[1:])
# all nonindexes that are not the same in each dataset
concat_new = set(k for k in getattr(datasets[0], subset)
if k not in concat_over and differs(k))
for k in getattr(datasets[0], subset):
if k not in concat_over:
# Compare the variable from every dataset against the one
# from the first dataset. Perform the minimum number of
# loads needed to avoid reading from disk multiple times,
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not v_lhs.equals(v_rhs):
concat_over.add(k)
equals[k] = False
# variables computed here must not be
# re-computed later on
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True

elif opt == 'all':
concat_new = (set(getattr(datasets[0], subset)) -
set(datasets[0].dims))
concat_over.update(set(getattr(datasets[0], subset)) -
set(datasets[0].dims))
elif opt == 'minimal':
concat_new = set()
pass
else:
raise ValueError("unexpected value for concat_%s: %s"
% (subset, opt))
raise ValueError("unexpected value for %s: %s" % (subset, opt))
else:
invalid_vars = [k for k in opt
if k not in getattr(datasets[0], subset)]
if invalid_vars:
raise ValueError('some variables in %s are not '
'%s on the first dataset: %s'
% (subset, subset_long_name, invalid_vars))
concat_new = set(opt)
return concat_new
if subset == 'coords':
raise ValueError(
'some variables in coords are not coordinates on '
'the first dataset: %s' % invalid_vars)
else:
raise ValueError(
'some variables in data_vars are not data variables on '
'the first dataset: %s' % invalid_vars)
concat_over.update(opt)

concat_over = set()
for ds in datasets:
concat_over.update(k for k, v in ds.variables.items()
if dim in v.dims)
concat_over.update(process_subset_opt(data_vars, 'data_vars'))
concat_over.update(process_subset_opt(coords, 'coords'))
if dim in datasets[0]:
concat_over.add(dim)
return concat_over
process_subset_opt(data_vars, 'data_vars')
process_subset_opt(coords, 'coords')
return concat_over, equals


def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
from .dataset import Dataset, as_dataset
from .dataset import Dataset

if compat not in ['equals', 'identical']:
raise ValueError("compat=%r invalid: must be 'equals' "
"or 'identical'" % compat)

dim, coord = _calc_concat_dim_coord(dim)
datasets = [as_dataset(ds) for ds in datasets]
# Make sure we're working on a copy (we'll be loading variables)
datasets = [ds.copy() for ds in datasets]
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])

concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)

def insert_result_variable(k, v):
assert isinstance(v, Variable)
@@ -239,11 +256,25 @@ def insert_result_variable(k, v):
elif (k in result_coord_names) != (k in ds.coords):
raise ValueError('%r is a coordinate in some datasets but not '
'others' % k)
elif (k in result_vars and k != dim and
not getattr(v, compat)(result_vars[k])):
verb = 'equal' if compat == 'equals' else compat
raise ValueError(
'variable %r not %s across datasets' % (k, verb))
elif k in result_vars and k != dim:
# Don't use Variable.identical as it internally invokes
# Variable.equals, and we may already know the answer
if compat == 'identical' and not utils.dict_equiv(
v.attrs, result_vars[k].attrs):
raise ValueError(
'variable %s not identical across datasets' % k)

# Proceed with equals()
try:
# May be populated when using the "different" method
is_equal = equals[k]
except KeyError:
result_vars[k].load()
is_equal = v.equals(result_vars[k])
if not is_equal:
raise ValueError(
'variable %s not equal across datasets' % k)


# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
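As a standalone illustration of the comparison strategy used in process_subset_opt above (a hypothetical helper, not part of the PR): each candidate xarray.Variable is loaded at most once, the loop stops at the first mismatch, and already-computed values are written back so a later concatenation does not trigger a second load.

    def variables_all_equal(variables):
        # 'variables' is assumed to be a list of xarray.Variable objects,
        # one per dataset, possibly backed by dask arrays.
        first = variables[0].load()      # load the reference variable once
        computed = []
        for var in variables[1:]:
            value = var.compute()        # compute each lazy variable exactly once
            computed.append(value)
            if not first.equals(value):
                # Keep what has been computed so far, so that concatenating
                # these variables later does not reload them from disk.
                for v, c in zip(variables[1:], computed):
                    v.data = c.data
                return False             # differs -> must be concatenated
        return True                      # all equal -> take from the first dataset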
2 changes: 2 additions & 0 deletions xarray/tests/test_combine.py
@@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pytest

from xarray import Dataset, DataArray, auto_combine, concat, Variable
from xarray.core.pycompat import iteritems, OrderedDict
@@ -268,6 +269,7 @@ def test_concat(self):
with self.assertRaisesRegexp(ValueError, 'not a valid argument'):
concat([foo, bar], dim='w', data_vars='minimal')

@pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/1586')
Review comment (Member):

I just merged your other PR -- can you remove this now?

def test_concat_encoding(self):
# Regression test for GH1297
ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))),
76 changes: 74 additions & 2 deletions xarray/tests/test_dask.py
@@ -250,6 +250,77 @@ def test_lazy_array(self):
actual = xr.concat([v[:2], v[2:]], 'x')
self.assertLazyAndAllClose(u, actual)

def test_concat_loads_variables(self):
# Test that concat() computes not-in-memory variables at most once
# and loads them in the output, while leaving the input unaltered.
d1 = build_dask_array('d1')
c1 = build_dask_array('c1')
d2 = build_dask_array('d2')
c2 = build_dask_array('c2')
d3 = build_dask_array('d3')
c3 = build_dask_array('c3')
# Note: c is a non-index coord.
# Index coords are loaded by IndexVariable.__init__.
ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)})
ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)})
ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)})

assert kernel_call_count == 0
out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
coords='different')
# each kernel is computed exactly once
assert kernel_call_count == 6
# variables are loaded in the output
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all')
# no extra kernel calls
assert kernel_call_count == 6
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c'])
# no extra kernel calls
assert kernel_call_count == 6
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[])
# variables are loaded once as we are validating that they're identical
assert kernel_call_count == 12
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
coords='different', compat='identical')
# compat=identical doesn't make any more kernel calls than compat=equals
assert kernel_call_count == 18
assert isinstance(out['d'].data, np.ndarray)
assert isinstance(out['c'].data, np.ndarray)

# When a mismatch is found partway through, stop computing the
# remaining variables, as doing so would bring no benefit
ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])})
out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different',
coords='different')
# the variables of ds1 and ds2 were computed, but those of ds3 were not
assert kernel_call_count == 22
assert isinstance(out['d'].data, dask.array.Array)
assert isinstance(out['c'].data, dask.array.Array)
# the data of ds1 and ds2 was loaded into numpy and then
# concatenated to the data of ds3. Thus, only ds3 is computed now.
out.compute()
assert kernel_call_count == 24

# Finally, test that the originals are unaltered
assert ds1['d'].data is d1
assert ds1['c'].data is c1
assert ds2['d'].data is d2
assert ds2['c'].data is c2
assert ds3['d'].data is d3
assert ds3['c'].data is c3

def test_groupby(self):
if LooseVersion(dask.__version__) == LooseVersion('0.15.3'):
pytest.xfail('upstream bug in dask: '
@@ -517,10 +588,11 @@ def test_dask_kwargs_dataset(method):
kernel_call_count = 0


def kernel():
def kernel(name):
"""Dask kernel to test pickling/unpickling and __repr__.
Must be global to make it pickleable.
"""
print("kernel(%s)" % name)
global kernel_call_count
kernel_call_count += 1
return np.ones(1, dtype=np.int64)
@@ -530,5 +602,5 @@ def build_dask_array(name):
global kernel_call_count
kernel_call_count = 0
return dask.array.Array(
dask={(name, 0): (kernel, )}, name=name,
dask={(name, 0): (kernel, name)}, name=name,
chunks=((1,),), dtype=np.int64)
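For context, a usage sketch (hypothetical, assuming the helpers above are in scope): every chunk produced by build_dask_array() runs kernel() once when computed, incrementing the module-level kernel_call_count that the assertions in test_concat_loads_variables rely on.

    a = build_dask_array('a')        # also resets kernel_call_count to 0
    assert kernel_call_count == 0    # nothing has been computed yet
    a.compute()                      # triggers kernel('a') exactly once
    assert kernel_call_count == 1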