Added unit tests
Fix loads when vars are found to be different halfway through
gimperiale committed Sep 24, 2017
1 parent 68d51c6 commit 1c30474
Showing 2 changed files with 78 additions and 17 deletions.
9 changes: 6 additions & 3 deletions xarray/core/combine.py
@@ -170,14 +170,17 @@ def process_subset_opt(opt, subset):
                         # keeping the RAM footprint low.
                         v_lhs = datasets[0].variables[k].load()
                         # We'll need to know later on if variables are equal.
+                        computed = []
                         for ds_rhs in datasets[1:]:
                             v_rhs = ds_rhs.variables[k].compute()
+                            computed.append(v_rhs)
                             if not v_lhs.equals(v_rhs):
                                 concat_over.add(k)
                                 equals[k] = False
-                                # rhs variable is not to be discarded, therefore
-                                # avoid re-computing it in the future
-                                ds_rhs.variables[k].data = v_rhs.data
+                                # computed variables are not to be re-computed
+                                # again in the future
+                                for ds, v in zip(datasets[1:], computed):
+                                    ds.variables[k].data = v.data
                                 break
                         else:
                             equals[k] = True
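
The pattern in this hunk is worth seeing in isolation: each lazy variable is computed at most once, every result is cached, and on the first mismatch the cached data is written back onto the inputs before breaking out, so nothing downstream pays for a second compute. Below is a minimal, self-contained sketch of that idea using plain dask arrays rather than xarray variables (load_until_different is a hypothetical helper, not part of xarray or dask):

    import dask.array as da
    import numpy as np

    def load_until_different(arrays):
        # Compare each lazy array against the first; keep every computed
        # result so the caller can reuse it instead of recomputing.
        lhs = np.asarray(arrays[0])            # reference value, loaded once
        computed = []
        for rhs in arrays[1:]:
            rhs_np = rhs.compute()             # forces this array exactly once
            computed.append(rhs_np)
            if not np.array_equal(lhs, rhs_np):
                break                          # later arrays stay lazy
        return computed

    arrs = [da.ones(3, chunks=3), da.ones(3, chunks=3),
            da.zeros(3, chunks=3), da.ones(3, chunks=3)]
    loaded = load_until_different(arrs)
    # The mismatch is found at the third array, so only two computes ran
    # and the final da.ones(...) was never evaluated.
    assert len(loaded) == 2

The early break is what the commit message refers to: once variables are found to be different halfway through, the remaining inputs are left lazy.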
86 changes: 72 additions & 14 deletions xarray/tests/test_dask.py
@@ -249,19 +249,76 @@ def test_lazy_array(self):
         actual = xr.concat([v[:2], v[2:]], 'x')
         self.assertLazyAndAllClose(u, actual)
 
-    def test_concat_loads_coords(self):
-        # Test that concat() computes dask-based, non-index
-        # coordinates exactly once and loads them in the output,
-        # while leaving the input unaltered.
-        y = build_dask_array('y')
-        ds1 = Dataset(coords={'x': [1], 'y': ('x', y)})
-        ds2 = Dataset(coords={'x': [1], 'y': ('x', [2.0])})
+    def test_concat_loads_variables(self):
+        # Test that concat() computes not-in-memory variables at most once
+        # and loads them in the output, while leaving the input unaltered.
+        d1 = build_dask_array('d1')
+        c1 = build_dask_array('c1')
+        d2 = build_dask_array('d2')
+        c2 = build_dask_array('c2')
+        d3 = build_dask_array('d3')
+        c3 = build_dask_array('c3')
+        # Note: c is a non-index coord.
+        # Index coords are loaded by IndexVariable.__init__.
+        ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)})
+        ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)})
+        ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)})
 
         assert kernel_call_count == 0
-        ds3 = xr.concat([ds1, ds2], dim='z')
-        assert kernel_call_count == 1
-        assert ds1['y'].data is y
-        assert isinstance(ds3['y'].data, np.ndarray)
-        assert ds3['y'].values.tolist() == [[1.0], [2.0]]
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # each kernel is computed exactly once
+        assert kernel_call_count == 6
+        # variables are loaded in the output
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all')
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c'])
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[])
+        # variables are loaded once as we are validating that they're identical
+        assert kernel_call_count == 12
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different', compat='identical')
+        # compat=identical doesn't do any more kernel calls than compat=equals
+        assert kernel_call_count == 18
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        # When the test for different turns true halfway through,
+        # stop computing variables as it would not have any benefit
+        ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])})
+        out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # the variables of ds1 and ds2 were computed, but those of ds3 were not
+        assert kernel_call_count == 22
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+        # the data of ds1 and ds2 was loaded into numpy and then
+        # concatenated to the data of ds3. Thus, only ds3 is computed now.
+        out.compute()
+        assert kernel_call_count == 24
+
+        # Finally, test that the originals are unaltered
+        assert ds1['d'].data is d1
+        assert ds1['c'].data is c1
+        assert ds2['d'].data is d2
+        assert ds2['c'].data is c2
+        assert ds3['d'].data is d3
+        assert ds3['c'].data is c3
 
     def test_groupby(self):
         u = self.eager_array
@@ -529,10 +586,11 @@ def test_dask_kwargs_dataset(method):
 kernel_call_count = 0
 
 
-def kernel():
+def kernel(name):
     """Dask kernel to test pickling/unpickling and __repr__.
     Must be global to make it pickleable.
     """
+    print("kernel(%s)" % name)
     global kernel_call_count
     kernel_call_count += 1
     return np.ones(1, dtype=np.int64)
@@ -542,5 +600,5 @@ def build_dask_array(name):
     global kernel_call_count
     kernel_call_count = 0
     return dask.array.Array(
-        dask={(name, 0): (kernel, )}, name=name,
+        dask={(name, 0): (kernel, name)}, name=name,
         chunks=((1,),), dtype=np.int64)
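
The counting kernel is what drives the assertions in test_concat_loads_variables: every chunk of an array built by build_dask_array() funnels through kernel(), which bumps a module-level counter, so the tests can check exactly how many computes happened. A standalone sketch of the trick, repeating the two helpers from the diff above (it assumes a dask version that accepts this low-level dask.array.Array constructor):

    import dask.array
    import numpy as np

    kernel_call_count = 0

    def kernel(name):
        # Bump the global counter every time dask evaluates this chunk.
        global kernel_call_count
        kernel_call_count += 1
        return np.ones(1, dtype=np.int64)

    def build_dask_array(name):
        global kernel_call_count
        kernel_call_count = 0
        return dask.array.Array(
            dask={(name, 0): (kernel, name)}, name=name,
            chunks=((1,),), dtype=np.int64)

    a = build_dask_array('a')
    assert kernel_call_count == 0    # building the graph computes nothing
    a.compute()
    a.compute()
    assert kernel_call_count == 2    # every compute() re-runs the kernel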
