
Commit

fix: support hf transformers with cls_token_id and sep_token_id set to None
percevalw committed Nov 22, 2024
1 parent ff8bd41 commit b3ed86e
Showing 2 changed files with 15 additions and 8 deletions.
changelog.md (1 addition, 0 deletions)
@@ -5,6 +5,7 @@
 ### Fixed
 
 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
+- Support huggingface transformers that do not set `cls_token_id` and `sep_token_id` (we now also look for these tokens in the `special_tokens_map` and `vocab` mappings)
 
 ## v0.14.0 (2024-11-14)
 
edsnlp/pipes/trainable/embeddings/transformer/transformer.py (14 additions, 8 deletions)
@@ -180,6 +180,16 @@ def __init__(
 self.max_tokens_per_device = max_tokens_per_device
 self._mem_per_unit = None
 self.span_getter = span_getter
+self.cls_token_id = self.tokenizer.cls_token_id
+self.sep_token_id = self.tokenizer.sep_token_id
+if self.cls_token_id is None:
+    [self.cls_token_id] = self.tokenizer.convert_tokens_to_ids(
+        [self.tokenizer.special_tokens_map["bos_token"]]
+    )
+if self.sep_token_id is None:
+    [self.sep_token_id] = self.tokenizer.convert_tokens_to_ids(
+        [self.tokenizer.special_tokens_map["eos_token"]]
+    )
 
 if new_tokens:
     self.tokenizer.add_tokens(sorted(set(t[1] for t in new_tokens)))
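
The fallback added in `__init__` above can be checked in isolation. A minimal sketch, assuming a tokenizer such as `gpt2` that exposes `bos_token`/`eos_token` in its `special_tokens_map` but sets no CLS/SEP tokens (the model name is only an example, not part of this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.cls_token_id, tokenizer.sep_token_id)  # None None

# Same lookup as the patch: fall back to bos/eos from special_tokens_map
# and resolve them to ids through the vocabulary.
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
if cls_token_id is None:
    [cls_token_id] = tokenizer.convert_tokens_to_ids(
        [tokenizer.special_tokens_map["bos_token"]]
    )
if sep_token_id is None:
    [sep_token_id] = tokenizer.convert_tokens_to_ids(
        [tokenizer.special_tokens_map["eos_token"]]
    )

print(cls_token_id, sep_token_id)  # both resolve to the <|endoftext|> id for gpt2
```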
@@ -364,11 +374,9 @@ def collate(self, batch):
 sample_word_lengths,
 sample_word_tokens,
 ):
-prompt_input_ids = [self.tokenizer.cls_token_id]
+prompt_input_ids = [self.cls_token_id]
 if span_prompt_input_ids:
-    prompt_input_ids.extend(
-        [*span_prompt_input_ids, self.tokenizer.sep_token_id]
-    )
+    prompt_input_ids.extend([*span_prompt_input_ids, self.sep_token_id])
 windows_offsets = list(
     range(0, max(len(span_text_input_ids) - overlap, 1), stride)
 )
@@ -379,9 +387,7 @@
     offset : offset + self.window
 ]
 window_input_ids = (
-    prompt_input_ids
-    + window_text_input_ids
-    + [self.tokenizer.sep_token_id]
+    prompt_input_ids + window_text_input_ids + [self.sep_token_id]
 )
 left_overlap = overlap // 2 if offset > 0 else 0
 right_overlap = (
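
Taken together, the two `collate` hunks above build each window as `[CLS] prompt [SEP] window text [SEP]`. A standalone sketch with hypothetical token ids (the values below are made up for illustration):

```python
cls_token_id, sep_token_id = 101, 102         # hypothetical special token ids
span_prompt_input_ids = [2003, 2023]          # hypothetical prompt wordpieces
window_text_input_ids = [7592, 2088, 999]     # hypothetical window wordpieces

prompt_input_ids = [cls_token_id]
if span_prompt_input_ids:
    prompt_input_ids.extend([*span_prompt_input_ids, sep_token_id])

window_input_ids = prompt_input_ids + window_text_input_ids + [sep_token_id]
print(window_input_ids)  # [101, 2003, 2023, 102, 7592, 2088, 999, 102]
```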
@@ -523,7 +529,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput:
 # )
 word_embeddings = torch.nn.functional.embedding_bag(
     input=batch["word_indices"],
-    weight=wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2)),
+    weight=wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2)),
     offsets=batch["word_offsets"],
 )
 word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding
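
A likely motivation for the `.view` to `.reshape` change in the last hunk: `.view` only works on tensors whose memory layout is compatible (typically contiguous), while `.reshape` silently copies when it has to. A small sketch of the difference (the shapes are assumptions, not the actual batch shapes):

```python
import torch

# Make a 3-d tensor non-contiguous on purpose, e.g. with a transpose.
wordpiece_embeddings = torch.randn(4, 6, 8).transpose(0, 1)

try:
    wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2))
except RuntimeError as err:
    print("view fails on non-contiguous input:", err)

# reshape copies when needed, so it succeeds regardless of layout.
flat = wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2))
print(flat.shape)  # torch.Size([24, 8])
```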
