From 6cf9e695d4f1ab09671da8c4d5e13a8686319172 Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Sun, 1 Oct 2023 08:03:02 +0100 Subject: [PATCH] Fix RoBERTa segment ids (#98) * Fox RoBERTa segment ids * Bump version * update CTRL tokenizer pretrained path --- main/Cargo.lock | 2 +- main/Cargo.toml | 2 +- main/src/tokenizer/roberta_tokenizer.rs | 3 ++- python-bindings/tests/test_tokenization_sst2.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/main/Cargo.lock b/main/Cargo.lock index e393199..65d5813 100644 --- a/main/Cargo.lock +++ b/main/Cargo.lock @@ -1104,7 +1104,7 @@ dependencies = [ [[package]] name = "rust_tokenizers" -version = "8.1.0" +version = "8.1.1" dependencies = [ "anyhow", "cached-path", diff --git a/main/Cargo.toml b/main/Cargo.toml index cb0d849..f740a29 100644 --- a/main/Cargo.toml +++ b/main/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust_tokenizers" -version = "8.1.0" +version = "8.1.1" authors = ["Guillaume Becquin "] edition = "2018" description = "High performance tokenizers for Rust" diff --git a/main/src/tokenizer/roberta_tokenizer.rs b/main/src/tokenizer/roberta_tokenizer.rs index ad18674..79e8591 100644 --- a/main/src/tokenizer/roberta_tokenizer.rs +++ b/main/src/tokenizer/roberta_tokenizer.rs @@ -292,7 +292,8 @@ impl Tokenizer for RobertaTokenizer { special_tokens_mask.extend(vec![0; length]); special_tokens_mask.push(1); token_segment_ids.push(0); - token_segment_ids.extend(vec![1; length + 1]); + // RobERTa does not use segment id, the entire sequence is set to zeros. + token_segment_ids.extend(vec![0; length + 1]); output.push(self.vocab.token_to_id(self.vocab.get_sep_value())); output.extend(tokens_ids_with_offsets_2_value.ids); output.push(self.vocab.token_to_id(self.vocab.get_sep_value())); diff --git a/python-bindings/tests/test_tokenization_sst2.py b/python-bindings/tests/test_tokenization_sst2.py index 3c7ad65..c47ec7c 100644 --- a/python-bindings/tests/test_tokenization_sst2.py +++ b/python-bindings/tests/test_tokenization_sst2.py @@ -123,7 +123,7 @@ def test_tokenization_distilbert(self): def test_tokenization_ctrl(self): # Given - self.base_tokenizer = CTRLTokenizer.from_pretrained('ctrl', + self.base_tokenizer = CTRLTokenizer.from_pretrained('Salesforce/ctrl', do_lower_case=True, cache_dir=self.test_dir) self.rust_tokenizer = PyCtrlTokenizer(