diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py
index 7a5e7fd587ca..781791b5bacc 100644
--- a/src/transformers/tokenization_t5.py
+++ b/src/transformers/tokenization_t5.py
@@ -187,6 +187,28 @@ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
         else:
             return token_ids + [self.eos_token_id]
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
diff --git a/src/transformers/tokenization_t5_fast.py b/src/transformers/tokenization_t5_fast.py
index e64d8ca7245e..0aba4763dfcd 100644
--- a/src/transformers/tokenization_t5_fast.py
+++ b/src/transformers/tokenization_t5_fast.py
@@ -191,6 +191,28 @@ def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
         token_ids_1 = token_ids_1 + [self.eos_token_id]
         return self.prefix_tokens + token_ids_0 + token_ids_1
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
     @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py
index 05d45d9b6936..7ef4b931bf44 100644
--- a/tests/test_tokenization_t5.py
+++ b/tests/test_tokenization_t5.py
@@ -223,6 +223,20 @@ def test_eos_in_input(self):
         self.assertEqual(expected_src_tokens, src_ids)
         self.assertEqual(expected_tgt_tokens, tgt_ids)
 
+    def test_token_type_ids(self):
+        src_text_1 = ["A first paragraph for summarization."]
+        src_text_2 = ["A second paragraph for summarization."]
+
+        fast_token_type_ids = self.t5_base_tokenizer_fast(
+            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
+        ).token_type_ids
+        slow_token_type_ids = self.t5_base_tokenizer(
+            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
+        ).token_type_ids
+
+        self.assertEqual(slow_token_type_ids, fast_token_type_ids)
+        self.assertEqual(len(slow_token_type_ids[0]), 18)
+
     def test_fast_and_slow_same_result(self):
         src_text = " Today is nice day "
         tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
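
For reviewers who want to verify the new behavior outside the test suite, here is a minimal sanity-check sketch (not part of the patch). It assumes this diff is applied to a transformers checkout, that sentencepiece is installed for the slow tokenizer, and that the t5-base checkpoint used by the test above is reachable:

# Sketch: exercise the new create_token_type_ids_from_sequences override.
# Assumes the patch is applied and "t5-base" (the checkpoint from the test)
# can be downloaded or is cached locally.
from transformers import T5Tokenizer, T5TokenizerFast

slow = T5Tokenizer.from_pretrained("t5-base")
fast = T5TokenizerFast.from_pretrained("t5-base")

# T5 appends </s> (eos) to each segment, so for a pair the mask has
# len(ids_0) + 1 + len(ids_1) + 1 entries, all of them zero.
ids_0 = slow.encode("A first paragraph for summarization.", add_special_tokens=False)
ids_1 = slow.encode("A second paragraph for summarization.", add_special_tokens=False)
mask = slow.create_token_type_ids_from_sequences(ids_0, ids_1)
assert mask == [0] * (len(ids_0) + len(ids_1) + 2)

# Slow and fast tokenizers should now agree end to end on sequence pairs.
slow_enc = slow("A first paragraph.", "A second paragraph.", return_token_type_ids=True)
fast_enc = fast("A first paragraph.", "A second paragraph.", return_token_type_ids=True)
assert slow_enc["token_type_ids"] == fast_enc["token_type_ids"]
assert len(slow_enc["token_type_ids"]) == len(slow_enc["input_ids"])

Without the override, the slow tokenizer falls back to the base-class default (zeros for the first segment, ones for the second), which disagrees with the all-zeros mask the fast tokenizer's backend produces; the override makes the two paths return the same all-zeros mask, whose length always equals that of input_ids.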