Skip to content

Commit

Permalink
Allow tokenizer to customize stop_tokens (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored May 17, 2024
1 parent 87aa565 commit 8128c8a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
Empty file added benchmarks/__init__.py
Empty file.
1 change: 1 addition & 0 deletions jetstream/engine/mock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TestVocab(Vocabulary):
eos_id = 1
bos_id = 2
unk_id = 3
stop_tokens = {pad_id, eos_id}
_base_vocab_size = 2**16
tokenizer: TestTokenizer = TestTokenizer()

Expand Down
7 changes: 6 additions & 1 deletion jetstream/engine/token_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def process_result_tokens(
slot_valid = slot_data.valid
slot_lengths = slot_data.lengths
samples, speculations = slot_tokens.shape
stop_tokens = [tokenizer.eos_id, tokenizer.pad_id]
stop_tokens = tokenizer.stop_tokens
# Stop anything which has reached its max length.
complete = complete | (slot_lengths > slot_max_length)
if debug:
Expand Down Expand Up @@ -395,6 +395,11 @@ def decode(self, token_ids: list[int]) -> str:
"""
return self.tokenizer.decode(token_ids)

@property
def stop_tokens(self) -> set[int]:
  """Set of token IDs at which decoding should stop.

  Delegates to the wrapped tokenizer's ``stop_tokens``.
  """
  return self.tokenizer.stop_tokens

@property
def pad_id(self) -> int:
"""ID of the pad token."""
Expand Down
5 changes: 5 additions & 0 deletions jetstream/engine/tokenizer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@ def eos_id(self) -> int:
@abc.abstractmethod
def bos_id(self) -> int:
"""ID of BOS token."""

@property
def stop_tokens(self) -> set[int]:
  """Set of token IDs at which decoding should stop.

  Defaults to the EOS and pad token IDs; subclasses may override to
  customize which tokens terminate generation.
  """
  return {self.eos_id, self.pad_id}

0 comments on commit 8128c8a

Please sign in to comment.