diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 41a84888..fbb3d9ee 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1169,27 +1169,24 @@ def __getitem__(self, key, *args, **kwargs): index = np.array([index]) if all(is_array_like(a) for a in [index, output]): - if ( - (len(index) == 1) - and (output.ndim == 1) - and ( - (len(output) > 1) - or isinstance(key, (int, slice)) - or isinstance(key[1], (list, np.ndarray)) - ) - ): - # reshape output of single index to preserve column axis if there are more than one columns being indexed - # or if column key is a list or array + if isinstance(key, tuple): + if ( + len(index) == 1 + and output.ndim == 1 + and not isinstance(key[1], int) + ): + output = output[None, :] + elif ( + (output.ndim == 1) + and isinstance(key[1], (list, np.ndarray)) + and (len(columns) == 1) + ): + # reshape output of single column if column key is a list or array + output = output[:, None] + # if getting a row (1 dim implied) + elif isinstance(key, Number): output = output[None, :] - elif ( - (output.ndim == 1) - and isinstance(key[1], (list, np.ndarray)) - and (len(columns) == 1) - ): - # reshape output of single column if column key is a list or array - output = output[:, None] - kwargs["columns"] = columns kwargs["metadata"] = self._metadata.loc[columns] return _initialize_tsd_output( diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 04948d88..354104ed 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1051,10 +1051,10 @@ def test_vertical_slicing(self, tsdframe, index): "row", [ 0, - # [0, 2], - # slice(20, 30), - # np.hstack([np.zeros(10, bool), True, True, True, np.zeros(87, bool)]), - # np.hstack([np.zeros(10, bool), True, np.zeros(89, bool)]), + [0, 2], + slice(20, 30), + np.hstack([np.zeros(10, bool), True, True, True, np.zeros(87, bool)]), + np.hstack([np.zeros(10, bool), True, np.zeros(89, bool)]), ], ) @pytest.mark.parametrize( @@ -1072,9 +1072,6 @@ def test_vert_and_horz_slicing(self, tsdframe, row, col, expected): if tsdframe.shape[1] == 1: if isinstance(col, list) and isinstance(col[0], int): col = [0] - elif isinstance(col, slice): - col = slice(0, 1) - expected = nap.Tsd elif isinstance(col, list) and isinstance(col[0], bool): col = [col[0]]