scalar_level in MultiIndex #1426

Closed · wants to merge 16 commits into from

Changes from 9 commits
33 changes: 12 additions & 21 deletions xarray/core/dataset.py
@@ -27,8 +27,7 @@
merge_data_and_coords)
from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable,
decode_numpy_dict_values, ensure_us_time_resolution)
from .variable import (Variable, as_variable, IndexVariable,
broadcast_variables)
from .variable import Variable, as_variable, IndexVariable, broadcast_variables
from .pycompat import (iteritems, basestring, OrderedDict,
integer_types, dask_array_type, range)
from .options import OPTIONS
@@ -576,21 +575,16 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
return obj

def _replace_indexes(self, indexes):
"""
        Convert the given index levels to scalar levels.
indexes: mapping from dimension name to new index.
"""
if not len(indexes):
return self
variables = self._variables.copy()
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
obj = self._replace_vars_and_dims(variables)

# switch from dimension to level names, if necessary
dim_names = {}
for dim, idx in indexes.items():
if not isinstance(idx, pd.MultiIndex) and idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
return obj
variables[name] = variables[name].reset_levels(idx.names)
return self._replace_vars_and_dims(variables)

def copy(self, deep=False):
"""Returns a copy of this dataset.
@@ -627,7 +621,7 @@ def _level_coords(self):
for cname in self._coord_names:
var = self.variables[cname]
if var.ndim == 1:
level_names = var.to_index_variable().level_names
level_names = var.all_level_names
if level_names is not None:
dim, = var.dims
level_coords.update({lname: dim for lname in level_names})
@@ -1127,10 +1121,7 @@ def isel(self, drop=False, **indexers):
Dataset.isel_points
DataArray.isel
"""
invalid = [k for k in indexers if k not in self.dims]
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

indexers = indexing.get_dim_pos_indexers(self, indexers)
# all indexers should be int, slice or np.ndarrays
indexers = [(k, (np.asarray(v)
if not isinstance(v, integer_types + (slice,))
@@ -1771,8 +1762,8 @@ def reorder_levels(self, inplace=False, **dim_order):
index = coord.to_index()
if not isinstance(index, pd.MultiIndex):
raise ValueError("coordinate %r has no MultiIndex" % dim)
replace_variables[dim] = IndexVariable(coord.dims,
index.reorder_levels(order))
replace_variables[dim] = IndexVariable(
coord.dims, index.reorder_levels(order))
variables = self._variables.copy()
variables.update(replace_variables)
return self._replace_vars_and_dims(variables, inplace=inplace)
@@ -1790,7 +1781,7 @@ def _stack_once(self, dims, new_dim):
variables[name] = stacked_var
else:
variables[name] = var.copy(deep=False)

# TODO move to IndexVariable method
# consider dropping levels that are unused?
levels = [self.get_index(dim) for dim in dims]
if hasattr(pd, 'RangeIndex'):
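For context, here is a minimal sketch of the kind of MultiIndex dataset these `isel` changes target. It uses only released xarray/pandas API for construction and label-based selection; the commented positional call at the end merely illustrates what routing `isel` through `get_dim_pos_indexers` is meant to enable on this branch.

```python
import numpy as np
import pandas as pd
import xarray as xr

# A dataset whose 'x' dimension carries a two-level MultiIndex.
midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]],
                                  names=['letter', 'number'])
ds = xr.Dataset({'data': ('x', np.arange(4))}, coords={'x': midx})

# Label-based selection by level name already exists in xarray.
print(ds.sel(letter='a'))

# On this branch, isel() forwards its keys to get_dim_pos_indexers(), so a
# positional indexer keyed by a level name would be translated to the 'x'
# dimension once only a single level remains (illustrative only, not
# released behavior):
# ds.isel(number=[0, 1])
```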
2 changes: 1 addition & 1 deletion xarray/core/formatting.py
@@ -225,7 +225,7 @@ def _summarize_coord_levels(coord, col_width, marker=u'-'):
[_summarize_var_or_coord(lname,
relevant_coord.get_level_variable(lname),
col_width, marker=marker)
for lname in coord.level_names])
for lname in coord.all_level_names])


def _not_remote(var):
149 changes: 147 additions & 2 deletions xarray/core/indexing.py
@@ -212,7 +212,6 @@ def convert_label_indexer(index, label, index_name='', method=None,
indexer, new_index = index.get_loc_level(
label, level=list(range(len(label)))
)

else:
label = _asarray_tuplesafe(label)
if label.ndim == 0:
@@ -230,6 +229,30 @@
return indexer, new_index


def get_dim_pos_indexers(data_obj, indexers):
"""
    Given an xarray data object and position-based indexers, return a mapping
    of positional indexers with only dimension names as keys.
    Indexing by a MultiIndex level name is only handled when that index has a
    single level.
"""
invalid = [k for k in indexers
if k not in data_obj.dims and k not in data_obj._level_coords]
if invalid:
raise ValueError("dimensions or multi-index levels %r do not exist"
% invalid)

dim_indexers = {}
for key, label in iteritems(indexers):
dim, = data_obj[key].dims
if key != dim:
            # key is assumed to be a multi-index level name
if len(data_obj.variables[dim].level_names) == 1:
                # only valid in the single-level case
dim_indexers[dim] = label
else:
dim_indexers[key] = label
return dim_indexers
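To make the single-level restriction concrete, here is a standalone toy version of this dispatch (the helper and its `level_to_dim` / `level_counts` arguments are hypothetical stand-ins, not xarray API): indexers keyed by a dimension pass through, while indexers keyed by a level name are rewritten to that level's dimension only when the MultiIndex has exactly one level.

```python
def _toy_dim_pos_indexers(dims, level_to_dim, level_counts, indexers):
    """Rewrite level-name keys to dimension names, mirroring the logic above."""
    invalid = [k for k in indexers
               if k not in dims and k not in level_to_dim]
    if invalid:
        raise ValueError("dimensions or multi-index levels %r do not exist"
                         % invalid)
    dim_indexers = {}
    for key, label in indexers.items():
        if key in dims:
            dim_indexers[key] = label
        elif level_counts[level_to_dim[key]] == 1:
            # positional indexing by level name only when one level remains
            dim_indexers[level_to_dim[key]] = label
    return dim_indexers


# With one remaining level 'number' on dimension 'x':
print(_toy_dim_pos_indexers({'x'}, {'number': 'x'}, {'x': 1},
                            {'number': [0, 2]}))
# {'x': [0, 2]}
```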

def get_dim_indexers(data_obj, indexers):
"""Given a xarray data object and label based indexers, return a mapping
of label indexers with only dimension names as keys.
@@ -498,7 +521,9 @@ class PandasIndexAdapter(utils.NDArrayMixin):
def __init__(self, array, dtype=None):
self.array = utils.safe_cast_to_index(array)
if dtype is None:
if isinstance(array, pd.PeriodIndex):
if array is None:
self._dtype = None
elif isinstance(array, pd.PeriodIndex):
dtype = np.dtype('O')
elif hasattr(array, 'categories'):
# category isn't a real numpy dtype
@@ -557,6 +582,126 @@ def __getitem__(self, key):

return result

def __eq__(self, other):
return self.array.equals(other.array)

def __repr__(self):
return ('%s(array=%r, dtype=%r)'
% (type(self).__name__, self.array, self.dtype))


class PandasMultiIndexAdapter(PandasIndexAdapter):
"""
    An extension of PandasIndexAdapter for pandas.MultiIndex, which keeps
    self._scalars, the names of the levels that should be treated as scalars.
"""
def __init__(self, array, dtype=None, scalars=[]):
super(PandasMultiIndexAdapter, self).__init__(array, dtype)
        # If array is 0-dimensional, the scalars argument is required because
        # a plain tuple does not carry level names
        if array.ndim == 0 and len(scalars) == 0:
            raise ValueError('Level names are required for 0d-array input.')

if isinstance(self.array, pd.MultiIndex):
            if any([s not in self.all_levels for s in scalars]):
                raise ValueError('scalars %r contain an invalid level name.'
                                 % (scalars,))
self._scalars = scalars

@property
def ndim(self):
return 1 if isinstance(self.array, pd.MultiIndex) else 0

@property
def shape(self):
if isinstance(self.array, pd.MultiIndex):
return (len(self.array),)
else:
return ()

@property
def all_levels(self):
""" All level names including scalars"""
# scalar case
if not isinstance(self.array, pd.MultiIndex):
return self.scalars
return self.array.names

@property
def levels(self):
""" Level names except for scalars"""
level_names = list(self.all_levels)
for s in self.scalars:
level_names.remove(s)
return level_names

@property
def scalars(self):
return self._scalars

def __getitem__(self, key):
if isinstance(key, tuple) and len(key) == 1:
# unpack key so it can index a pandas.Index object (pandas.Index
# objects don't like tuples)
key, = key

result = self.array[key]
if isinstance(result, tuple): # if a single item is chosen
result = utils.to_0d_object_array(result)
return PandasMultiIndexAdapter(result, dtype=self.dtype,
scalars=self.all_levels)
return PandasMultiIndexAdapter(result, dtype=self.dtype,
scalars=self.scalars)

def get_level_values(self, level):
"""
        Return the values of the given level. For a scalar level, return its
        single (first) item.
"""
if level in self.scalars:
return self.array.get_level_values(level)[0]
elif level in self.levels:
return self.array.get_level_values(level)
else:
raise ValueError('level %r does not exist.' % level)

    def set_scalar(self, scalars):
        if any([s not in self.levels for s in scalars]):
            raise ValueError('scalars %r contain an invalid level name.'
                             % (scalars,))
        # keep scalars in order, preserving any existing ones
        new_scalars = []
        for l in self.all_levels:
            if l in self._scalars or l in scalars:
                new_scalars.append(l)
        # if all the levels become scalar, reduce to size 1
        if set(new_scalars) == set(self.all_levels):
            return type(self)(np.array(self.array[0]), self.dtype,
                              new_scalars)
        return type(self)(self.array, self.dtype, new_scalars)

    def reset_scalar(self, scalars):
        if len(scalars) == 0:
            return self

        level_names = self.all_levels
        if not isinstance(self.array, pd.MultiIndex):
            # in the 0d-case, rebuild a MultiIndex from the stored tuple
            if any([s not in self.scalars for s in scalars]):
                raise ValueError('scalars %r contain an invalid level name.'
                                 % (scalars,))
            array = pd.MultiIndex.from_tuples([self.array.item()],
                                              names=self.scalars)
        elif any([s not in level_names for s in scalars]):
            raise ValueError('scalars %r contain an invalid level name.'
                             % (scalars,))
        else:
            array = self.array

        # copy so that removing items does not mutate self._scalars
        new_scalars = list(self.scalars)
        for s in scalars:
            new_scalars.remove(s)
        return type(self)(array, self.dtype, new_scalars)

def __eq__(self, other):
return self.array.equals(other.array) and self.scalars == other.scalars

def __repr__(self):
return ('%s(array=%r, dtype=%r, scalars=%r)'
% (type(self).__name__, self.array, self.dtype, self.scalars))
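The `_scalars` bookkeeping above can be illustrated with plain pandas (a rough sketch of the intent, not this adapter's API): after picking a single label on one level, that level still exists in the MultiIndex but is constant, and it is this constant level that the PR proposes to record as a scalar level instead of dropping it.

```python
import pandas as pd

midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2, 3]],
                                  names=['letter', 'number'])

# Keep only the rows where the 'letter' level equals 'a'. The result still
# has both levels, but 'letter' is now constant across all entries.
selected = midx[midx.get_level_values('letter') == 'a']
print(selected.names)                                 # 'letter' and 'number'
print(selected.get_level_values('letter').unique())   # just 'a'
print(selected.get_level_values('number'))            # 1, 2, 3

# PandasMultiIndexAdapter records such constant levels in self._scalars, so
# get_level_values('letter') would return the single value 'a' while
# 'number' remains the effective index.
```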