Prevent pad ids, special tokens displaying in generate (#1211)
RdoubleA authored Aug 5, 2024
1 parent 3653c4a commit 8519c35
Showing 6 changed files with 18 additions and 45 deletions.
1 change: 0 additions & 1 deletion recipes/eleuther_eval.py
@@ -156,7 +156,6 @@ def _model_generate(
self._model,
context,
max_generated_tokens=self.max_gen_toks,
- pad_id=self._tokenizer.pad_id,
temperature=temperature,
top_k=None, # do_sample is not supported currently
stop_tokens=self._tokenizer.stop_tokens,
2 changes: 0 additions & 2 deletions recipes/generate.py
@@ -155,7 +155,6 @@ def generate(self, cfg: DictConfig):
temperature=cfg.temperature,
top_k=cfg.top_k,
stop_tokens=self._tokenizer.stop_tokens,
- pad_id=self._tokenizer.pad_id,
custom_generate_next_token=custom_generate_next_token,
)
t = time.perf_counter() - t0
@@ -169,7 +168,6 @@ def generate(self, cfg: DictConfig):
temperature=cfg.temperature,
top_k=cfg.top_k,
stop_tokens=self._tokenizer.stop_tokens,
- pad_id=self._tokenizer.pad_id,
custom_generate_next_token=custom_generate_next_token,
)
t = time.perf_counter() - t0
35 changes: 0 additions & 35 deletions tests/torchtune/utils/test_generation.py
@@ -290,38 +290,3 @@ def test_stop_tokens_batched_uneven_stopping(
]

assert outputs == expected_output

- def test_stop_tokens_batched_uneven_stoppin_with_diff_pad_id(
-     self, generation_model_batched, prompt_tokens_batched
- ):
-     """
-     Test to check if the `generate` function produces the right output when stop tokens are
-     provided, but this time in batched format. This time, seq 0 should hit a stop token before seq 1.
-     We expect the output to be the length of seq 1, but the first seq should be truncated. This test
-     also uses a diff pad_id than the default, so we want to make sure it gets applied correctly.
-     """
-     temperature = 0.6
-     top_k = 100

-     # This is the first token generated by the model
-     # so it should stop immediately
-     stop_tokens = [3987, 3979]

-     torch.manual_seed(42)

-     outputs = utils.generate(
-         model=generation_model_batched,
-         prompt=prompt_tokens_batched,
-         max_generated_tokens=10,
-         pad_id=1,
-         temperature=temperature,
-         top_k=top_k,
-         stop_tokens=stop_tokens,
-     )

-     expected_output = [
-         [2, 3, 4, 5, 6, 7, 8, 9, 3987, 1],
-         [2, 3, 4, 5, 6, 7, 8, 9, 3958, 3979],
-     ]

-     assert outputs == expected_output
9 changes: 8 additions & 1 deletion torchtune/models/llama3/_tokenizer.py
@@ -134,6 +134,7 @@ def decode(
self,
token_ids: List[int],
truncate_at_eos: bool = True,
+ skip_special_tokens: bool = True,
) -> str:
"""
Decode a list of token ids into a string.
@@ -142,11 +143,17 @@
token_ids (List[int]): The list of token ids.
truncate_at_eos (bool): Whether to truncate the string at the end of
sequence token. Default is True.
+ skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
+     Default is True.
Returns:
str: The decoded string.
"""
- return self.tt_model.decode(token_ids, truncate_at_eos=truncate_at_eos)
+ return self.tt_model.decode(
+     token_ids,
+     truncate_at_eos=truncate_at_eos,
+     skip_special_tokens=skip_special_tokens,
+ )

def _tokenize_header(self, message: Message) -> List[int]:
"""
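For reference, a usage sketch of the new flag, not part of this commit: the tokenizer path and prompt below are placeholders, and the builder/encode helpers are assumed to match the torchtune Llama3 tokenizer API shown above.

```python
# Hypothetical usage; the tokenizer path and prompt are placeholders.
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B/original/tokenizer.model")
token_ids = tokenizer.encode("Hello world", add_bos=True, add_eos=True)

# Default behavior after this change: BOS/EOS and other special tokens are
# dropped from the decoded string.
print(tokenizer.decode(token_ids))

# Opt out to inspect the raw output, special tokens included.
print(tokenizer.decode(token_ids, skip_special_tokens=False))
```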
11 changes: 10 additions & 1 deletion torchtune/modules/tokenizers/_tiktoken.py
@@ -138,6 +138,7 @@ def decode(
self,
token_ids: List[int],
truncate_at_eos: bool = True,
+ skip_special_tokens: bool = True,
) -> str:
"""
Decode a list of token ids into a string.
@@ -146,6 +147,8 @@
token_ids (List[int]): The list of token ids.
truncate_at_eos (bool): Whether to truncate the string at the end of
sequence token. Default is True.
+ skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
+     Default is True.
Returns:
str: The decoded string.
@@ -157,5 +160,11 @@
k = None
if k:
token_ids = token_ids[:k]
- token_ids = [token_id for token_id in token_ids if token_id != self.bos_id]
+ if skip_special_tokens:
+     token_ids = [
+         token_id
+         for token_id in token_ids
+         if token_id not in self.tt_model._special_tokens.values()
+         and token_id != self.bos_id
+     ]
return self.tt_model.decode(token_ids)
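The filtering above only touches ids registered in the tokenizer's special-token map, plus BOS. A minimal, self-contained sketch of that logic, using a made-up special-token mapping rather than the real tiktoken model:

```python
from typing import Dict, List

# Made-up mapping for illustration; the real values live in the tiktoken model.
special_tokens: Dict[str, int] = {"<|begin_of_text|>": 1000, "<|end_of_text|>": 1001}
bos_id = special_tokens["<|begin_of_text|>"]


def strip_special_ids(token_ids: List[int], skip_special_tokens: bool = True) -> List[int]:
    """Drop special-token ids (and BOS) before decoding, mirroring the diff above."""
    if skip_special_tokens:
        token_ids = [
            t
            for t in token_ids
            if t not in special_tokens.values() and t != bos_id
        ]
    return token_ids


print(strip_special_ids([1000, 5, 6, 7, 1001]))                             # [5, 6, 7]
print(strip_special_ids([1000, 5, 6, 7, 1001], skip_special_tokens=False))  # unchanged
```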
5 changes: 0 additions & 5 deletions torchtune/utils/_generation.py
@@ -66,7 +66,6 @@ def generate(
prompt: torch.Tensor,
*,
max_generated_tokens: int,
- pad_id: int = 0,
temperature: float = 1.0,
top_k: Optional[int] = None,
stop_tokens: Optional[List[int]] = None,
@@ -80,7 +79,6 @@ def generate(
prompt (torch.Tensor): tensor with the token IDs associated with the given prompt,
with shape either [seq_length] or [bsz x seq_length]
max_generated_tokens (int): number of tokens to be generated
- pad_id (int): token ID to use for padding, default 0.
temperature (float): value to scale the predicted logits by, default 1.0.
top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities,
default None.
@@ -179,8 +177,5 @@ def generate(
# mask out generated tokens in seqs that already hit a stop token
if stop_tokens is not None:
generated_tokens = generated_tokens * stop_token_mask
- # if pad_id is not 0, replace 0 with pad_id
- if pad_id != 0:
-     generated_tokens[generated_tokens == 0] = pad_id

return generated_tokens.tolist()
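Taken together with the recipe changes above, a call-site sketch: `model` and `tokenizer` are assumed to be already instantiated and cache-enabled as in recipes/generate.py, and the prompt and sampling values are placeholders.

```python
import torch

from torchtune import utils

# `model` and `tokenizer` come from the recipe's setup; placeholders here.
prompt = torch.tensor(tokenizer.encode("Tell me a joke", add_bos=True, add_eos=False))

generated = utils.generate(
    model=model,
    prompt=prompt,
    max_generated_tokens=128,
    temperature=0.8,
    top_k=300,
    stop_tokens=tokenizer.stop_tokens,  # pad_id is no longer passed
)

# decode() now defaults to skip_special_tokens=True, so stop markers and other
# special tokens no longer show up in the printed text.
print(tokenizer.decode(generated[0]))
```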
