Adding the prepare_seq2seq_batch function to ProphetNet #8515

Merged
src/transformers/modeling_prophetnet.py (4 changes: 4 additions & 0 deletions)
@@ -1766,6 +1766,10 @@ def forward(
         logits = predict_logits[:, 0]
         logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

+        # To use .view in loss computation, make sure that logits is contiguous.
+        if not logits.is_contiguous():
+            logits = logits.contiguous()
+
         loss = None
         if labels is not None:
             loss = self._compute_loss(predict_logits, labels)
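For context on this change: indexing predict_logits along the n-gram dimension, as in predict_logits[:, 0], yields a tensor whose strides no longer match a dense layout, and Tensor.view refuses to operate on such non-contiguous memory. A minimal standalone sketch of the failure mode, with hypothetical shapes (not part of the PR):

import torch

# Hypothetical shape: (batch, ngram, sequence_length, vocab_size).
predict_logits = torch.randn(2, 3, 5, 7)

# Selecting index 0 along dim 1 keeps the original strides,
# so the result is not contiguous in memory.
logits = predict_logits[:, 0]
print(logits.is_contiguous())  # False

# logits.view(-1, 7) would raise a RuntimeError at this point;
# .contiguous() copies the data into a dense layout, after which .view works.
flat = logits.contiguous().view(-1, 7)
print(flat.shape)  # torch.Size([10, 7])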
src/transformers/tokenization_prophetnet.py (44 changes: 43 additions & 1 deletion)
@@ -17,8 +17,10 @@
 import os
 from typing import List, Optional, Tuple

+from .file_utils import add_start_docstrings
 from .tokenization_bert import BasicTokenizer, WordpieceTokenizer
-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
+from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
 from .utils import logging


@@ -286,3 +288,43 @@ def build_inputs_with_special_tokens(
             return token_ids_0 + [self.sep_token_id]
         sep = [self.sep_token_id]
         return token_ids_0 + sep + token_ids_1 + sep
+
+    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: List[str],
+        tgt_texts: Optional[List[str]] = None,
+        max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: Optional[str] = None,
+        truncation: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        if max_length is None:
+            max_length = self.max_len
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
+        labels_and_decoder_mask = self(
+            tgt_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            padding=padding,
+            max_length=max_target_length,
+            truncation=truncation,
+            **kwargs,
+        )
+        model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
+        return model_inputs
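A minimal usage sketch for the new method, mirroring the test added below (the model call in the final comment is an assumption for illustration, not part of this PR):

from transformers import ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["A long paragraph for summarization."],
    tgt_texts=["Summary of the text."],
    return_tensors="pt",
)
# batch is a BatchEncoding with input_ids and attention_mask for the source
# texts plus labels from the tokenized targets, ready to pass to a seq2seq
# model such as ProphetNetForConditionalGeneration via model(**batch).

Note that only the target input_ids are kept: despite the labels_and_decoder_mask name, the target attention mask is discarded.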
tests/test_tokenization_prophetnet.py (25 changes: 24 additions & 1 deletion)
@@ -17,7 +17,8 @@
 import os
 import unittest

-from transformers.testing_utils import slow
+from transformers import BatchEncoding
+from transformers.testing_utils import require_torch, slow
 from transformers.tokenization_bert import (
     BasicTokenizer,
     WordpieceTokenizer,
@@ -150,6 +151,28 @@ def test_wordpiece_tokenizer(self):

         self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

+    @require_torch
+    def test_prepare_seq2seq_batch(self):
+        tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
+
+        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text,
+            tgt_texts=tgt_text,
+            return_tensors="pt",
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        result = list(batch.input_ids.numpy()[0])
+        self.assertListEqual(expected_src_tokens, result)
+
+        self.assertEqual((2, 9), batch.input_ids.shape)
+        self.assertEqual((2, 9), batch.attention_mask.shape)
+
     def test_is_whitespace(self):
         self.assertTrue(_is_whitespace(" "))
         self.assertTrue(_is_whitespace("\t"))
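The expected ids in the test end in 102, the [SEP] id that build_inputs_with_special_tokens appends (ProphetNet adds no leading [CLS] to a single sequence, as the context lines in the tokenizer diff show). A quick sketch for inspecting that mapping, assuming the same checkpoint as the test:

from transformers import ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
ids = tokenizer("A long paragraph for summarization.")["input_ids"]
print(ids)  # should match expected_src_tokens, ending in 1012 ('.') and 102 ('[SEP]')
print(tokenizer.convert_ids_to_tokens(ids))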