mBART support for run_summarization.py #15125

Merged
merged 4 commits into from
Jan 12, 2022
51 changes: 49 additions & 2 deletions examples/pytorch/summarization/run_summarization.py
@@ -37,6 +37,10 @@
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
MBart50Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
@@ -64,6 +68,9 @@
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)

# A list of all multilingual tokenizers which require the lang attribute.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
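# Illustrative example (not part of this diff), assuming the facebook/mbart-large-50
# checkpoint: each of these tokenizers needs a language code before encoding, e.g.
#     tok = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
#     tok.src_lang = "ro_RO"  # encoded inputs now start with the ro_RO language token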


@dataclass
class ModelArguments:
@@ -114,6 +121,8 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""

lang: str = field(default=None, metadata={"help": "Language id for summarization."})

dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
@@ -217,12 +226,24 @@ class DataTrainingArguments:
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)

forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the decoder_start_token_id. "
"Useful for multilingual models like mBART, where the first generated token needs to be "
"the target language token."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
elif self.lang is None:
raise ValueError("Need to specify the language.")

else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
@@ -370,6 +391,12 @@ def main():

model.resize_token_embeddings(len(tokenizer))

if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
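# Illustrative note (not part of this diff): both branches resolve the --lang code to
# its vocabulary id, which mBART uses as the decoder start token, e.g.
#     tok = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
#     tok.lang_code_to_id["ro_RO"] == tok.convert_tokens_to_ids("ro_RO")  # same id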

if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

@@ -406,6 +433,21 @@ def main():
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return

if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
assert (
data_args.lang is not None
), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"

tokenizer.src_lang = data_args.lang
tokenizer.tgt_lang = data_args.lang

# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
# as the first generated token. We ask the user to explicitly provide this as the --forced_bos_token argument.
forced_bos_token_id = (
tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
)
model.config.forced_bos_token_id = forced_bos_token_id
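# Illustrative sketch (not part of this diff), assuming an mBART-50 checkpoint: with
# forced_bos_token_id set, generation emits the target language token right after the
# decoder start token, e.g.
#     model.config.forced_bos_token_id = tokenizer.lang_code_to_id["ro_RO"]
#     generated = model.generate(**batch)  # first token after decoder_start is ro_RO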

# Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None:
@@ -436,14 +478,16 @@ def main():
)

def preprocess_function(examples):

# remove pairs where at least one record is None

inputs, targets = [], []
for i in range(len(examples[text_column])):
if examples[text_column][i] is not None and examples[summary_column][i] is not None:
inputs.append(examples[text_column][i])
targets.append(examples[summary_column][i])

inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

@@ -637,6 +681,9 @@ def compute_metrics(eval_preds):
else:
kwargs["dataset"] = data_args.dataset_name

if data_args.lang is not None:
kwargs["language"] = data_args.lang

if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
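Taken together, the diff wires the mBART language handling end to end: --lang sets the tokenizer's source and target language, the language code becomes the decoder start token, and --forced_bos_token forces the first generated token. A minimal standalone sketch of the same flow (illustrative only; the checkpoint name, language code, and input text are example values, not part of this PR):

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

lang = "ro_RO"  # what --lang would be on the command line
tokenizer.src_lang = lang
tokenizer.tgt_lang = lang

# what --forced_bos_token does in the script
model.config.forced_bos_token_id = tokenizer.lang_code_to_id[lang]

batch = tokenizer("Un articol de rezumat ...", return_tensors="pt", truncation=True)
summary_ids = model.generate(**batch, max_length=64)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))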