diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a882d3a955469..61fe7ee7a1b39 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -294,6 +294,8 @@ def __getitem__(self, item: PositionalIndexer): ) # We are not an array indexer, so maybe e.g. a slice or integer # indexer. We dispatch to pyarrow. + if type(item) == np.int64: + item = item.item() value = self._data[item] if isinstance(value, pa.ChunkedArray): return type(self)(value) diff --git a/pandas/core/groupby/grouper.py b/pandas/core/groupby/grouper.py index b9f4166b475ca..eec24acd3305c 100644 --- a/pandas/core/groupby/grouper.py +++ b/pandas/core/groupby/grouper.py @@ -733,6 +733,11 @@ def get_grouper( """ group_axis = obj._get_axis(axis) + tuple_unified = False + if isinstance(key, list): + if len(key) == 1 and isinstance(key[0], str): + tuple_unified = True + # validate that the passed single level is compatible with the passed # axis of the object if level is not None: @@ -918,7 +923,12 @@ def is_in_obj(gpr) -> bool: # create the internals grouper grouper = ops.BaseGrouper( - group_axis, groupings, sort=sort, mutated=mutated, dropna=dropna + group_axis, + groupings, + tuple_unified=tuple_unified, + sort=sort, + mutated=mutated, + dropna=dropna, ) return grouper, frozenset(exclusions), obj diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 6dc4ccfa8e1ee..fd72a61065404 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -711,6 +711,7 @@ def __init__( self, axis: Index, groupings: Sequence[grouper.Grouping], + tuple_unified: bool = False, sort: bool = True, group_keys: bool = True, mutated: bool = False, @@ -721,6 +722,7 @@ def __init__( self.axis = axis self._groupings: list[grouper.Grouping] = list(groupings) + self.tuple_unified = tuple_unified self._sort = sort self.group_keys = group_keys self.mutated = mutated @@ -779,13 +781,13 @@ def _get_grouper(self): @final @cache_readonly def group_keys_seq(self): - if len(self.groupings) == 1: + if len(self.groupings) == 1 and self.tuple_unified is False: return self.levels[0] - else: - ids, _, ngroups = self.group_info - # provide "flattened" iterator for multi-group setting - return get_flattened_list(ids, ngroups, self.levels, self.codes) + ids, _, ngroups = self.group_info + + # provide "flattened" iterator for multi-group setting + return get_flattened_list(ids, ngroups, self.levels, self.codes) @final def apply( @@ -1123,12 +1125,13 @@ def __init__( binlabels, mutated: bool = False, indexer=None, + tuple_unified: bool = False, ) -> None: self.bins = ensure_int64(bins) self.binlabels = ensure_index(binlabels) self.mutated = mutated self.indexer = indexer - + self.tuple_unified = False # These lengths must match, otherwise we could call agg_series # with empty self.bins, which would raise in libreduction. assert len(self.binlabels) == len(self.bins) diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 03aad0ef64dec..e8e4598d1472a 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -161,6 +161,9 @@ def __internal_pivot_table( pass values = list(values) + if isinstance(keys, list): + if len(keys) == 1: + keys = keys[0] grouped = data.groupby(keys, observed=observed, sort=sort) agged = grouped.agg(aggfunc) if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns): @@ -367,7 +370,11 @@ def _all_key(key): margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc) cat_axis = 1 - for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed): + for keys, piece in table.groupby(level=0, axis=cat_axis, observed=observed): + if isinstance(keys, tuple): + (key,) = keys + else: + key = keys all_key = _all_key(key) # we are going to mutate this, so need to copy! diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index ee7493813f13a..fe75f552c6633 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -175,7 +175,8 @@ def __init__( # For `hist` plot, need to get grouped original data before `self.data` is # updated later if self.by is not None and self._kind == "hist": - self._grouped = data.groupby(self.by) + bymodi = fix_groupby_singlelist_input(by) + self._grouped = data.groupby(bymodi) self.kind = kind @@ -1832,3 +1833,10 @@ def blank_labeler(label, value): leglabels = labels if labels is not None else idx for p, l in zip(patches, leglabels): self._append_legend_handles_labels(p, l) + + +def fix_groupby_singlelist_input(keys): + if isinstance(keys, list): + if len(keys) == 1 and isinstance(keys[0], str): + keys = keys[0] + return keys diff --git a/pandas/plotting/_matplotlib/groupby.py b/pandas/plotting/_matplotlib/groupby.py index 4f1cd3f38343a..0c87db697b342 100644 --- a/pandas/plotting/_matplotlib/groupby.py +++ b/pandas/plotting/_matplotlib/groupby.py @@ -108,7 +108,8 @@ def reconstruct_data_with_by( 1 3.0 4.0 NaN NaN 2 NaN NaN 5.0 6.0 """ - grouped = data.groupby(by) + bymodi = fix_groupby_singlelist_input(by) + grouped = data.groupby(bymodi) data_list = [] for key, group in grouped: @@ -134,3 +135,10 @@ def reformat_hist_y_given_by( if by is not None and len(y.shape) > 1: return np.array([remove_na_arraylike(col) for col in y.T]).T return remove_na_arraylike(y) + + +def fix_groupby_singlelist_input(keys): + if isinstance(keys, list): + if len(keys) == 1 and isinstance(keys[0], str): + keys = keys[0] + return keys diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index 3b151d67c70be..61408d7c946c9 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -67,7 +67,8 @@ def _args_adjust(self): # where subplots are created based on by argument if is_integer(self.bins): if self.by is not None: - grouped = self.data.groupby(self.by)[self.columns] + bymodi = fix_groupby_singlelist_input(self.by) + grouped = self.data.groupby(bymodi)[self.columns] self.bins = [self._calculate_bins(group) for key, group in grouped] else: self.bins = self._calculate_bins(self.data) @@ -271,6 +272,8 @@ def _grouped_plot( grouped = data.groupby(by) if column is not None: grouped = grouped[column] + if isinstance(by, list) and len(by) == 1: + by = [by] naxes = len(grouped) fig, axes = create_subplots( @@ -528,3 +531,10 @@ def hist_frame( maybe_adjust_figure(fig, wspace=0.3, hspace=0.3) return axes + + +def fix_groupby_singlelist_input(keys): + if isinstance(keys, list): + if len(keys) == 1 and isinstance(keys[0], str): + keys = keys[0] + return keys diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 920b869ef799b..0ce73c73dd2a4 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -2795,3 +2795,21 @@ def test_groupby_none_column_name(): result = df.groupby(by=[None]).sum() expected = DataFrame({"b": [2, 5], "c": [9, 13]}, index=Index([1, 2], name=None)) tm.assert_frame_equal(result, expected) + + +def test_groupby_iterator_one_grouper(): + df = DataFrame(columns=["a", "b", "c"], index=["x", "y"]) + df.loc["y"] = Series({"a": 1, "b": 5, "c": 2}) + expected = True + + values, _ = next(iter(df.groupby(["a", "b"]))) + result = isinstance(values, tuple) + assert result == expected + + values, _ = next(iter(df.groupby(["a"]))) + result = isinstance(values, tuple) + assert result == expected + + values, _ = next(iter(df.groupby("a"))) + result = isinstance(values, int) + assert result == expected diff --git a/pandas/tests/reshape/merge/test_join.py b/pandas/tests/reshape/merge/test_join.py index 905c2af2d22a5..ee460eb365d25 100644 --- a/pandas/tests/reshape/merge/test_join.py +++ b/pandas/tests/reshape/merge/test_join.py @@ -718,7 +718,9 @@ def _check_join(left, right, result, join_col, how="left", lsuffix="_x", rsuffix # some smoke tests for c in join_col: assert result[c].notna().all() - + if isinstance(join_col, list): + if len(join_col) == 1: + join_col = join_col[0] left_grouped = left.groupby(join_col) right_grouped = right.groupby(join_col)