diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 539fdbda1c7..007b9640e20 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -170,14 +170,17 @@ def process_subset_opt(opt, subset):
                         # keeping the RAM footprint low.
                         v_lhs = datasets[0].variables[k].load()
                         # We'll need to know later on if variables are equal.
+                        computed = []
                         for ds_rhs in datasets[1:]:
                             v_rhs = ds_rhs.variables[k].compute()
+                            computed.append(v_rhs)
                             if not v_lhs.equals(v_rhs):
                                 concat_over.add(k)
                                 equals[k] = False
-                                # rhs variable is not to be discarded, therefore
-                                # avoid re-computing it in the future
-                                ds_rhs.variables[k].data = v_rhs.data
+                                # keep the already-computed variables so they
+                                # are not re-computed in the future
+                                for ds, v in zip(datasets[1:], computed):
+                                    ds.variables[k].data = v.data
                                 break
                         else:
                             equals[k] = True
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 98420452f72..a73d3913bde 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -249,19 +249,76 @@ def test_lazy_array(self):
         actual = xr.concat([v[:2], v[2:]], 'x')
         self.assertLazyAndAllClose(u, actual)
 
-    def test_concat_loads_coords(self):
-        # Test that concat() computes dask-based, non-index
-        # coordinates exactly once and loads them in the output,
-        # while leaving the input unaltered.
-        y = build_dask_array('y')
-        ds1 = Dataset(coords={'x': [1], 'y': ('x', y)})
-        ds2 = Dataset(coords={'x': [1], 'y': ('x', [2.0])})
+    def test_concat_loads_variables(self):
+        # Test that concat() computes not-in-memory variables at most once
+        # and loads them in the output, while leaving the input unaltered.
+        d1 = build_dask_array('d1')
+        c1 = build_dask_array('c1')
+        d2 = build_dask_array('d2')
+        c2 = build_dask_array('c2')
+        d3 = build_dask_array('d3')
+        c3 = build_dask_array('c3')
+        # Note: c is a non-index coord.
+        # Index coords are loaded by IndexVariable.__init__.
+        ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)})
+        ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)})
+        ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)})
+
         assert kernel_call_count == 0
-        ds3 = xr.concat([ds1, ds2], dim='z')
-        assert kernel_call_count == 1
-        assert ds1['y'].data is y
-        assert isinstance(ds3['y'].data, np.ndarray)
-        assert ds3['y'].values.tolist() == [[1.0], [2.0]]
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # each kernel is computed exactly once
+        assert kernel_call_count == 6
+        # variables are loaded in the output
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all')
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c'])
+        # no extra kernel calls
+        assert kernel_call_count == 6
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[])
+        # variables are loaded once as we are validating that they're identical
+        assert kernel_call_count == 12
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different',
+                        coords='different', compat='identical')
+        # compat=identical doesn't do any more kernel calls than compat=equals
+        assert kernel_call_count == 18
+        assert isinstance(out['d'].data, np.ndarray)
+        assert isinstance(out['c'].data, np.ndarray)
+
+        # When the comparison for 'different' turns true halfway through,
+        # stop computing the remaining variables, as that has no benefit
+        ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])})
+        out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different',
+                        coords='different')
+        # the variables of ds1 and ds2 were computed, but those of ds3 were not
+        assert kernel_call_count == 22
+        assert isinstance(out['d'].data, dask.array.Array)
+        assert isinstance(out['c'].data, dask.array.Array)
+        # the data of ds1 and ds2 was loaded into numpy and then
+        # concatenated to the data of ds3. Thus, only ds3 is computed now.
+        out.compute()
+        assert kernel_call_count == 24
+
+        # Finally, test that the originals are unaltered
+        assert ds1['d'].data is d1
+        assert ds1['c'].data is c1
+        assert ds2['d'].data is d2
+        assert ds2['c'].data is c2
+        assert ds3['d'].data is d3
+        assert ds3['c'].data is c3
 
     def test_groupby(self):
         u = self.eager_array
@@ -529,10 +586,11 @@ def test_dask_kwargs_dataset(method):
 kernel_call_count = 0
 
 
-def kernel():
+def kernel(name):
     """Dask kernel to test pickling/unpickling and __repr__.
     Must be global to make it pickleable.
     """
+    print("kernel(%s)" % name)
     global kernel_call_count
     kernel_call_count += 1
     return np.ones(1, dtype=np.int64)
@@ -542,5 +600,5 @@ def build_dask_array(name):
    global kernel_call_count
    kernel_call_count = 0
    return dask.array.Array(
-        dask={(name, 0): (kernel, )}, name=name,
+        dask={(name, 0): (kernel, name)}, name=name,
         chunks=((1,),), dtype=np.int64)
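
Note for reviewers: below is a minimal standalone sketch of the user-visible
effect of this patch. It is illustrative only, not part of the patch: `calls`
and `lazy_array` are made-up names, and it builds the lazy array with
dask.delayed/da.from_delayed rather than the low-level dask.array.Array
constructor used in the test above. With data_vars='different', concat() has
to compare the variables, so it computes each input kernel exactly once and
stores the loaded result in the output, leaving the inputs unaltered:

    import dask
    import dask.array as da
    import numpy as np
    import xarray as xr

    calls = 0

    def kernel():
        # count how many times dask actually executes this task
        global calls
        calls += 1
        return np.ones(1, dtype=np.int64)

    def lazy_array():
        # single-chunk dask array backed by the counting kernel;
        # pure=False gives each call its own task key
        return da.from_delayed(dask.delayed(kernel, pure=False)(),
                               shape=(1,), dtype=np.int64)

    ds1 = xr.Dataset(data_vars={'d': ('x', lazy_array())})
    ds2 = xr.Dataset(data_vars={'d': ('x', lazy_array())})
    out = xr.concat([ds1, ds2], dim='n', data_vars='different')
    assert calls == 2                             # each kernel ran exactly once
    assert isinstance(out['d'].data, np.ndarray)  # loaded in the output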