Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from pandas.errors import InvalidIndexError
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.cast import sanitize_to_nanoseconds
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_dtype,
is_list_like,
is_scalar,
is_timedelta64_dtype,
)

import pandas.core.algorithms as algorithms
Expand Down Expand Up @@ -466,9 +465,6 @@ def __init__(
if isinstance(grouper, (Series, Index)) and name is None:
self.name = grouper.name

if isinstance(grouper, MultiIndex):
self.grouper = grouper._values

# we have a single grouper which may be a myriad of things,
# some of which are dependent on the passing in level

Expand Down Expand Up @@ -506,14 +502,9 @@ def __init__(
self.grouper = grouper._get_grouper()

else:
if self.grouper is None and self.name is not None and self.obj is not None:
self.grouper = self.obj[self.name]

elif isinstance(self.grouper, (list, tuple)):
self.grouper = com.asarray_tuplesafe(self.grouper)

# a passed Categorical
elif is_categorical_dtype(self.grouper):
if is_categorical_dtype(self.grouper):

self.grouper, self.all_grouper = recode_for_groupby(
self.grouper, self.sort, observed
Expand All @@ -539,7 +530,7 @@ def __init__(
)

# we are done
if isinstance(self.grouper, Grouping):
elif isinstance(self.grouper, Grouping):
self.grouper = self.grouper.grouper

# no level passed
Expand All @@ -562,14 +553,10 @@ def __init__(
self.grouper = None # Try for sanity
raise AssertionError(errmsg)

# if we have a date/time-like grouper, make sure that we have
# Timestamps like
if getattr(self.grouper, "dtype", None) is not None:
if is_datetime64_dtype(self.grouper):
self.grouper = self.grouper.astype("datetime64[ns]")
elif is_timedelta64_dtype(self.grouper):

self.grouper = self.grouper.astype("timedelta64[ns]")
if isinstance(self.grouper, np.ndarray):
# if we have a date/time-like grouper, make sure that we have
# Timestamps like
self.grouper = sanitize_to_nanoseconds(self.grouper)

def __repr__(self) -> str:
return f"Grouping({self.name})"
Expand Down Expand Up @@ -876,9 +863,14 @@ def _convert_grouper(axis: Index, grouper):
return grouper._values
else:
return grouper.reindex(axis)._values
elif isinstance(grouper, (list, Series, Index, np.ndarray)):
elif isinstance(grouper, MultiIndex):
return grouper._values
elif isinstance(grouper, (list, tuple, Series, Index, np.ndarray)):
if len(grouper) != len(axis):
raise ValueError("Grouper and axis must be same length")

if isinstance(grouper, (list, tuple)):
grouper = com.asarray_tuplesafe(grouper)
return grouper
else:
return grouper
Expand Down