Skip to content

Commit 442972c

Browse files
committed
Allow tokenizer to customize stop_tokens
1 parent e4952fb commit 442972c

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

jetstream/engine/token_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def process_result_tokens(
187187
slot_valid = slot_data.valid
188188
slot_lengths = slot_data.lengths
189189
samples, speculations = slot_tokens.shape
190-
stop_tokens = [tokenizer.eos_id, tokenizer.pad_id]
190+
stop_tokens = tokenizer.stop_tokens
191191
# Stop anything which has reached its max length.
192192
complete = complete | (slot_lengths > slot_max_length)
193193
if debug:
@@ -349,6 +349,11 @@ def bos_id(self) -> int:
349349
"""ID of the BOS token."""
350350
return self.vocab.bos_id
351351

352+
@property
353+
def stop_tokens(self) -> set[int]:
354+
"""IDs of the stop tokens."""
355+
return {self.eos_id, self.pad_id}
356+
352357

353358
class TikToken(tokenizer_api.Tokenizer):
354359
"""Tokenizer to convert strings to token ids and vice-versa."""
@@ -394,6 +399,11 @@ def decode(self, token_ids: list[int]) -> str:
394399
str: String generated from the token ids.
395400
"""
396401
return self.tokenizer.decode(token_ids)
402+
403+
@property
404+
def stop_tokens(self) -> set[int]:
405+
"""IDs of the stop tokens."""
406+
return self.tokenizer.stop_tokens
397407

398408
@property
399409
def pad_id(self) -> int:

jetstream/engine/tokenizer_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,8 @@ def eos_id(self) -> int:
6565
@abc.abstractmethod
6666
def bos_id(self) -> int:
6767
"""ID of BOS token."""
68+
69+
@property
70+
def stop_tokens(self) -> set[int]:
71+
"""IDs of the stop tokens."""
72+
return {self.eos_id, self.pad_id}

0 commit comments

Comments
 (0)