[cherry-pick] Ensure actual length of transformed seq by T5 Transform cannot be longer than specified max_len (#2143)

* Ensure actual seq length is not 1 greater than the specified max seq len

Summary: Fixes a bug wherein the output sequence length can be 1 greater than the specified max_seq_len, because the EOS token is appended after truncation.

Reviewed By: yohann-benchetrit

Differential Revision: D44708339

fbshipit-source-id: dc24d268caa4eb783df8ac7e8a2463e33d3244d9

* Format new changes

---------

Co-authored-by: Joe Cummings <jrcummings@meta.com>
joecummings and Joe Cummings authored Apr 6, 2023
1 parent df95462 commit fdc4858
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtext/models/t5/t5_transform.py

@@ -45,7 +45,7 @@ def __init__(self, sp_model_path: str, max_seq_len: int, eos_idx: int, padding_i
         self.max_seq_len = max_seq_len
         self.eos_idx = eos_idx
         self.padding_idx = padding_idx
-        self.pipeline = T.Sequential(T.Truncate(self.max_seq_len), T.AddToken(token=self.eos_idx, begin=False))
+        self.pipeline = T.Sequential(T.Truncate(self.max_seq_len - 1), T.AddToken(token=self.eos_idx, begin=False))

     def forward(self, input: Union[str, List[str]]) -> torch.Tensor:
         """
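A minimal sketch of the off-by-one this commit fixes, in plain Python standing in for torchtext's `T.Truncate` / `T.AddToken` pipeline. The helper names and the `EOS = 1` value are illustrative placeholders, not torchtext's actual API or token index.

```python
EOS = 1  # placeholder EOS token id, not torchtext's real eos_idx

def transform_before(token_ids, max_seq_len):
    # Old behavior: truncate to max_seq_len, THEN append EOS, so a
    # full-length input produces max_seq_len + 1 tokens.
    return token_ids[:max_seq_len] + [EOS]

def transform_after(token_ids, max_seq_len):
    # Fixed behavior: truncate to max_seq_len - 1 so the appended EOS
    # keeps the total length within max_seq_len.
    return token_ids[:max_seq_len - 1] + [EOS]

ids = list(range(10))
print(len(transform_before(ids, max_seq_len=8)))  # 9: exceeds the limit
print(len(transform_after(ids, max_seq_len=8)))   # 8: respects the limit
```

Note that the fix only changes behavior for inputs at or above the length limit; shorter inputs still simply get an EOS appended.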
