diff --git a/tests/torchtune/data/test_collate.py b/tests/torchtune/data/test_collate.py index bbbc4338f1..a1dfbf5a4b 100644 --- a/tests/torchtune/data/test_collate.py +++ b/tests/torchtune/data/test_collate.py @@ -57,6 +57,12 @@ def test_left_pad_sequence(self): expected = torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7], [8, 9, 10, 11, 12]]) assert torch.equal(result, expected) + result = left_pad_sequence([a, b, c], batch_first=False, padding_value=0) + expected = torch.tensor( + [[0, 0, 8], [0, 4, 9], [1, 5, 10], [2, 6, 11], [3, 7, 12]] + ) + assert torch.equal(result, expected) + class TestPaddedCollate: def test_padded_collate_classifier_labels(self): diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index f459b8e249..005e5b9755 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -49,7 +49,7 @@ def left_pad_sequence( map(lambda x: torch.flip(x, dims=[0]), sequences), batch_first=batch_first, padding_value=padding_value, - ).flip(dims=[1]) + ).flip(dims=[int(batch_first)]) def padded_collate(