diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 0301a62f44e..c79033cecd6 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -33,7 +33,6 @@
     safe_cast_to_index,
 )
 from xarray.core.options import _get_keep_attrs
-from xarray.core.pycompat import integer_types
 from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
 from xarray.core.utils import (
     either_dict_or_kwargs,
@@ -1296,7 +1295,11 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
         return ops.where_method(self, cond, other)
 
     def _first_or_last(self, op, skipna, keep_attrs):
-        if isinstance(self._group_indices[0], integer_types):
+        if all(
+            isinstance(maybe_slice, slice)
+            and (maybe_slice.stop == maybe_slice.start + 1)
+            for maybe_slice in self._group_indices
+        ):
             # NB. this is currently only used for reductions along an existing
             # dimension
             return self._obj
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 6f48b6e11bc..07e2233c4fd 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -1099,14 +1099,14 @@ def test_stack_groupby_unsorted_coord(self):
         y_vals = [2, 3]
 
         arr = xr.DataArray(data, dims=dims, coords={"y": y_vals})
-        actual1 = arr.stack(z=dims).groupby("z", squeeze=False).first()
+        actual1 = arr.stack(z=dims).groupby("z").first()
         midx1 = pd.MultiIndex.from_product([[0, 1], [2, 3]], names=dims)
         expected1 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx1})
         assert_equal(actual1, expected1)
 
         # GH: 3287.  Note that y coord values are not in sorted order.
         arr = xr.DataArray(data, dims=dims, coords={"y": y_vals[::-1]})
-        actual2 = arr.stack(z=dims).groupby("z", squeeze=False).first()
+        actual2 = arr.stack(z=dims).groupby("z").first()
         midx2 = pd.MultiIndex.from_product([[0, 1], [3, 2]], names=dims)
         expected2 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx2})
         assert_equal(actual2, expected2)
@@ -1420,7 +1420,7 @@ def test_groupby_first_and_last(self):
         actual = array.groupby(by).first()
         assert_identical(expected, actual)
 
-        actual = array.groupby("x", squeeze=False).first()
+        actual = array.groupby("x").first()
         expected = array  # should be a no-op
         assert_identical(expected, actual)