diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index fb917dfb254..6a8cd9c457b 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -126,7 +126,8 @@ def test_dask_distributed_write_netcdf_with_dimensionless_variables( @requires_cftime @requires_netCDF4 -def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) @@ -135,9 +136,55 @@ def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): - for parallel in (False, True): - with xr.open_mfdataset(file_path, parallel=parallel) as tf: - assert_identical(tf["test"], da) + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +# TODO: move this to test_backends.py +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)