Skip to content

Commit

Permalink
Add fill_value for concat and auto_combine (#2964)
Browse files Browse the repository at this point in the history
* add fill_value option for concat and auto_combine

* add tests for fill_value in concat and auto_combine

* remove errant whitespace

* add fill_value description to doc-string

* add missing assert
  • Loading branch information
zdgriffith authored and shoyer committed May 27, 2019
1 parent 7edf2e2 commit 6dc8b60
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 20 deletions.
55 changes: 35 additions & 20 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from . import utils
from . import utils, dtypes
from .alignment import align
from .merge import merge
from .variable import IndexVariable, Variable, as_variable
Expand All @@ -14,7 +14,7 @@

def concat(objs, dim=None, data_vars='all', coords='different',
compat='equals', positions=None, indexers=None, mode=None,
concat_over=None):
concat_over=None, fill_value=dtypes.NA):
"""Concatenate xarray objects along a new or existing dimension.
Parameters
Expand Down Expand Up @@ -66,6 +66,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
List of integer arrays which specifies the integer positions to which
to assign each dataset along the concatenated dimension. If not
supplied, objects are concatenated in the provided order.
fill_value : scalar, optional
Value to use for newly missing values
indexers, mode, concat_over : deprecated
Returns
Expand Down Expand Up @@ -117,7 +119,7 @@ def concat(objs, dim=None, data_vars='all', coords='different',
else:
raise TypeError('can only concatenate xarray Dataset and DataArray '
'objects, got %s' % type(first_obj))
return f(objs, dim, data_vars, coords, compat, positions)
return f(objs, dim, data_vars, coords, compat, positions, fill_value)


def _calc_concat_dim_coord(dim):
Expand Down Expand Up @@ -212,7 +214,8 @@ def process_subset_opt(opt, subset):
return concat_over, equals


def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
def _dataset_concat(datasets, dim, data_vars, coords, compat, positions,
fill_value=dtypes.NA):
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
Expand All @@ -225,7 +228,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
dim, coord = _calc_concat_dim_coord(dim)
# Make sure we're working on a copy (we'll be loading variables)
datasets = [ds.copy() for ds in datasets]
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
datasets = align(*datasets, join='outer', copy=False, exclude=[dim],
fill_value=fill_value)

concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)

Expand Down Expand Up @@ -317,7 +321,7 @@ def ensure_common_dims(vars):


def _dataarray_concat(arrays, dim, data_vars, coords, compat,
positions):
positions, fill_value=dtypes.NA):
arrays = list(arrays)

if data_vars != 'all':
Expand All @@ -336,14 +340,15 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat,
datasets.append(arr._to_temp_dataset())

ds = _dataset_concat(datasets, dim, data_vars, coords, compat,
positions)
positions, fill_value)
result = arrays[0]._from_temp_dataset(ds, name)

result.name = result_name(arrays)
return result


def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
def _auto_concat(datasets, dim=None, data_vars='all', coords='different',
fill_value=dtypes.NA):
if len(datasets) == 1 and dim is None:
# There is nothing more to combine, so kick out early.
return datasets[0]
Expand All @@ -366,7 +371,8 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
'supply the ``concat_dim`` argument '
'explicitly')
dim, = concat_dims
return concat(datasets, dim=dim, data_vars=data_vars, coords=coords)
return concat(datasets, dim=dim, data_vars=data_vars,
coords=coords, fill_value=fill_value)


_CONCAT_DIM_DEFAULT = utils.ReprObject('<inferred>')
Expand Down Expand Up @@ -442,7 +448,8 @@ def _check_shape_tile_ids(combined_tile_ids):


def _combine_nd(combined_ids, concat_dims, data_vars='all',
coords='different', compat='no_conflicts'):
coords='different', compat='no_conflicts',
fill_value=dtypes.NA):
"""
Concatenates and merges an N-dimensional structure of datasets.
Expand Down Expand Up @@ -472,13 +479,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all',
dim=concat_dim,
data_vars=data_vars,
coords=coords,
compat=compat)
compat=compat,
fill_value=fill_value)
combined_ds = list(combined_ids.values())[0]
return combined_ds


def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
coords, compat):
coords, compat, fill_value=dtypes.NA):
# Group into lines of datasets which must be combined along dim
# need to sort by _new_tile_id first for groupby to work
# TODO remove all these sorted OrderedDicts once python >= 3.6 only
Expand All @@ -490,7 +498,8 @@ def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
combined_ids = OrderedDict(sorted(group))
datasets = combined_ids.values()
new_combined_ids[new_id] = _auto_combine_1d(datasets, dim, compat,
data_vars, coords)
data_vars, coords,
fill_value)
return new_combined_ids


Expand All @@ -500,18 +509,20 @@ def vars_as_keys(ds):

def _auto_combine_1d(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
compat='no_conflicts',
data_vars='all', coords='different'):
data_vars='all', coords='different',
fill_value=dtypes.NA):
# This is just the old auto_combine function (which only worked along 1D)
if concat_dim is not None:
dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim
sorted_datasets = sorted(datasets, key=vars_as_keys)
grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys)
concatenated = [_auto_concat(list(ds_group), dim=dim,
data_vars=data_vars, coords=coords)
data_vars=data_vars, coords=coords,
fill_value=fill_value)
for id, ds_group in grouped_by_vars]
else:
concatenated = datasets
merged = merge(concatenated, compat=compat)
merged = merge(concatenated, compat=compat, fill_value=fill_value)
return merged


Expand All @@ -521,7 +532,7 @@ def _new_tile_id(single_id_ds_pair):


def _auto_combine(datasets, concat_dims, compat, data_vars, coords,
infer_order_from_coords, ids):
infer_order_from_coords, ids, fill_value=dtypes.NA):
"""
Calls logic to decide concatenation order before concatenating.
"""
Expand Down Expand Up @@ -550,12 +561,14 @@ def _auto_combine(datasets, concat_dims, compat, data_vars, coords,

# Repeatedly concatenate then merge along each dimension
combined = _combine_nd(combined_ids, concat_dims, compat=compat,
data_vars=data_vars, coords=coords)
data_vars=data_vars, coords=coords,
fill_value=fill_value)
return combined


def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
compat='no_conflicts', data_vars='all', coords='different'):
compat='no_conflicts', data_vars='all', coords='different',
fill_value=dtypes.NA):
"""Attempt to auto-magically combine the given datasets into one.
This method attempts to combine a list of datasets into a single entity by
inspecting metadata and using a combination of concat and merge.
Expand Down Expand Up @@ -596,6 +609,8 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
Details are in the documentation of concat
coords : {'minimal', 'different', 'all' or list of str}, optional
Details are in the documentation of concat
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
Expand All @@ -622,4 +637,4 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
return _auto_combine(datasets, concat_dims=concat_dims, compat=compat,
data_vars=data_vars, coords=coords,
infer_order_from_coords=infer_order_from_coords,
ids=False)
ids=False, fill_value=fill_value)
42 changes: 42 additions & 0 deletions xarray/tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from xarray import DataArray, Dataset, Variable, auto_combine, concat
from xarray.core import dtypes
from xarray.core.combine import (
_auto_combine, _auto_combine_1d, _auto_combine_all_along_first_dim,
_check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions,
Expand Down Expand Up @@ -237,6 +238,20 @@ def test_concat_multiindex(self):
assert expected.equals(actual)
assert isinstance(actual.x.to_index(), pd.MultiIndex)

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_concat_fill_value(self, fill_value):
    """Dataset concat should use ``fill_value`` for positions created by outer alignment."""
    ds0 = Dataset({'a': ('x', [2, 3]), 'x': [1, 2]})
    ds1 = Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    if fill_value == dtypes.NA:
        # the default sentinel materializes as NaN in a float result array
        fill_value = np.nan
    expected = Dataset(
        {'a': (('t', 'x'), [[fill_value, 2, 3], [1, 2, fill_value]])},
        {'x': [0, 1, 2]})
    actual = concat([ds0, ds1], dim='t', fill_value=fill_value)
    assert_identical(actual, expected)


class TestConcatDataArray:
def test_concat(self):
Expand Down Expand Up @@ -306,6 +321,19 @@ def test_concat_lazy(self):
assert combined.shape == (2, 3, 3)
assert combined.dims == ('z', 'x', 'y')

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_concat_fill_value(self, fill_value):
    """DataArray concat should use ``fill_value`` for positions created by outer alignment."""
    first = DataArray([1, 2], coords=[('x', [1, 2])])
    second = DataArray([1, 2], coords=[('x', [1, 3])])
    if fill_value == dtypes.NA:
        # the default sentinel materializes as NaN in a float result array
        fill_value = np.nan
    expected = DataArray(
        [[1, 2, fill_value], [1, fill_value, 2]],
        dims=['y', 'x'], coords={'x': [1, 2, 3]})
    actual = concat((first, second), dim='y', fill_value=fill_value)
    assert_identical(actual, expected)


class TestAutoCombine:

Expand Down Expand Up @@ -417,6 +445,20 @@ def test_auto_combine_no_concat(self):
{'baz': [100]})
assert_identical(expected, actual)

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_auto_combine_fill_value(self, fill_value):
    """auto_combine should forward ``fill_value`` to the underlying concat."""
    ds0 = Dataset({'a': ('x', [2, 3]), 'x': [1, 2]})
    ds1 = Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    if fill_value == dtypes.NA:
        # the default sentinel materializes as NaN in a float result array
        fill_value = np.nan
    expected = Dataset(
        {'a': (('t', 'x'), [[fill_value, 2, 3], [1, 2, fill_value]])},
        {'x': [0, 1, 2]})
    actual = auto_combine([ds0, ds1], concat_dim='t', fill_value=fill_value)
    assert_identical(expected, actual)


def assert_combined_tile_ids_equal(dict1, dict2):
assert len(dict1) == len(dict2)
Expand Down

0 comments on commit 6dc8b60

Please sign in to comment.