Skip to content

Commit

Permalink
Fixed left_pad_sequence - correctly flip dims based on batch_first (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mirceamironenco authored Sep 8, 2024
1 parent 5d5caca commit 68d4f3e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/torchtune/data/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def test_left_pad_sequence(self):
expected = torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7], [8, 9, 10, 11, 12]])
assert torch.equal(result, expected)

result = left_pad_sequence([a, b, c], batch_first=False, padding_value=0)
expected = torch.tensor(
[[0, 0, 8], [0, 4, 9], [1, 5, 10], [2, 6, 11], [3, 7, 12]]
)
assert torch.equal(result, expected)


class TestPaddedCollate:
def test_padded_collate_classifier_labels(self):
Expand Down
2 changes: 1 addition & 1 deletion torchtune/data/_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def left_pad_sequence(
map(lambda x: torch.flip(x, dims=[0]), sequences),
batch_first=batch_first,
padding_value=padding_value,
).flip(dims=[1])
).flip(dims=[int(batch_first)])


def padded_collate(
Expand Down

0 comments on commit 68d4f3e

Please sign in to comment.