From 7fc4e44c82dcc7235bc42931ee307dba0a98a0aa Mon Sep 17 00:00:00 2001
From: Edoardo Federici <49756048+banda-larga@users.noreply.github.com>
Date: Wed, 12 Jan 2022 15:27:00 +0100
Subject: [PATCH 1/4] Update run_summarization.py

---
 .../summarization/run_summarization.py       | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 9 deletions(-)

diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index c2d0ff87951f21..448f5ead3af6b2 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -37,6 +37,10 @@
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
+    MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -64,6 +68,8 @@
     with FileLock(".lock") as lock:
         nltk.download("punkt", quiet=True)
 
+# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
 
 @dataclass
 class ModelArguments:
@@ -114,6 +120,9 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
+    source_lang: str = field(default=None, metadata={"help": "Source language id for summarization."})
+    target_lang: str = field(default=None, metadata={"help": "Target language id for summarization."})
+
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -217,12 +226,24 @@ class DataTrainingArguments:
         },
     )
     source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+        default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+    )
+
+    forced_bos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
+            "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated "
+            "token needs to be the target language token."
+        },
     )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
+        elif self.source_lang is None or self.target_lang is None:
+            raise ValueError("Need to specify the source language and the target language.")
+
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -369,6 +390,12 @@ def main():
     )
     model.resize_token_embeddings(len(tokenizer))
+
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -406,6 +433,26 @@ def main():
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
 
+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+        assert data_args.target_lang is not None and data_args.source_lang is not None, (
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
+            "--target_lang arguments."
+        )
+
+        tokenizer.src_lang = data_args.source_lang
+        tokenizer.tgt_lang = data_args.target_lang
+
+        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
+        forced_bos_token_id = (
+            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+        )
+        model.config.forced_bos_token_id = forced_bos_token_id
+
+    # Get the language codes for input/target.
+    source_lang = data_args.source_lang.split("_")[0]
+    target_lang = data_args.target_lang.split("_")[0]
+
     # Get the column names for input/target.
     dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
     if data_args.text_column is None:
@@ -436,14 +483,8 @@ def main():
     )
 
     def preprocess_function(examples):
-
-        # remove pairs where at least one record is None
-        inputs, targets = [], []
-        for i in range(len(examples[text_column])):
-            if examples[text_column][i] is not None and examples[summary_column][i] is not None:
-                inputs.append(examples[text_column][i])
-                targets.append(examples[summary_column][i])
-
+        inputs = examples[text_column]
+        targets = examples[summary_column]
         inputs = [prefix + inp for inp in inputs]
 
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
@@ -636,6 +677,10 @@ def compute_metrics(eval_preds):
         kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
     else:
         kwargs["dataset"] = data_args.dataset_name
+
+    languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
+    if len(languages) > 0:
+        kwargs["language"] = languages
 
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
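
Note (outside the patch): the new decoder_start_token_id block above takes two different paths because the two tokenizer families expose language codes differently. A minimal standalone sketch of what each branch resolves; the checkpoint names are illustrative, not taken from the patch:

    from transformers import MBart50TokenizerFast, MBartTokenizer

    # MBartTokenizer keeps an explicit language-code -> token-id mapping.
    mbart = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
    print(mbart.lang_code_to_id["en_XX"])

    # The mBART-50 tokenizers treat the language code as an ordinary token,
    # which is why the script falls back to convert_tokens_to_ids for them.
    mbart50 = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    print(mbart50.convert_tokens_to_ids("en_XX"))
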
From bd305a6a5086d23a828edaa7c02d632bab889d7c Mon Sep 17 00:00:00 2001
From: Edoardo Federici <49756048+banda-larga@users.noreply.github.com>
Date: Wed, 12 Jan 2022 18:53:37 +0100
Subject: [PATCH 2/4] Fixed languages and added missing code

---
 .../summarization/run_summarization.py       | 41 ++++++++++++++++++++++-------------------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index 448f5ead3af6b2..28c8089375522f 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -68,7 +68,7 @@
     with FileLock(".lock") as lock:
         nltk.download("punkt", quiet=True)
 
-# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
+# A list of all multilingual tokenizers which require a lang attribute.
 MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
 
 @dataclass
@@ -120,8 +120,7 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
-    source_lang: str = field(default=None, metadata={"help": "Source language id for summarization."})
-    target_lang: str = field(default=None, metadata={"help": "Target language id for summarization."})
+    lang: str = field(default=None, metadata={"help": "Language id for summarization."})
 
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
@@ -241,8 +240,8 @@ class DataTrainingArguments:
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
-        elif self.source_lang is None or self.target_lang is None:
-            raise ValueError("Need to specify the source language and the target language.")
+        elif self.lang is None:
+            raise ValueError("Need to specify the language.")
 
         else:
             if self.train_file is not None:
@@ -393,9 +392,9 @@ def main():
 
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         if isinstance(tokenizer, MBartTokenizer):
-            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
         else:
-            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -434,24 +433,23 @@ def main():
         return
 
     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
-        assert data_args.target_lang is not None and data_args.source_lang is not None, (
-            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
-            "--target_lang arguments."
+        assert data_args.lang is not None, (
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires the --lang argument"
         )
 
-        tokenizer.src_lang = data_args.source_lang
-        tokenizer.tgt_lang = data_args.target_lang
+        tokenizer.src_lang = data_args.lang
+        tokenizer.tgt_lang = data_args.lang
 
         # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
         forced_bos_token_id = (
             tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
         )
         model.config.forced_bos_token_id = forced_bos_token_id
 
     # Get the language codes for input/target.
-    source_lang = data_args.source_lang.split("_")[0]
-    target_lang = data_args.target_lang.split("_")[0]
+    source_lang = data_args.lang.split("_")[0]
+    target_lang = data_args.lang.split("_")[0]
 
     # Get the column names for input/target.
     dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
@@ -483,6 +481,12 @@ def main():
     )
 
     def preprocess_function(examples):
-        inputs = examples[text_column]
-        targets = examples[summary_column]
+        # remove pairs where at least one record is None
+
+        inputs, targets = [], []
+        for i in range(len(examples[text_column])):
+            if examples[text_column][i] is not None and examples[summary_column][i] is not None:
+                inputs.append(examples[text_column][i])
+                targets.append(examples[summary_column][i])
+
         inputs = [prefix + inp for inp in inputs]
@@ -678,9 +684,8 @@ def compute_metrics(eval_preds):
     else:
         kwargs["dataset"] = data_args.dataset_name
 
-    languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
-    if len(languages) > 0:
-        kwargs["language"] = languages
+    if data_args.lang is not None:
+        kwargs["language"] = data_args.lang
 
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
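
Note (outside the patch): with a single --lang, the script now sets src_lang and tgt_lang to the same code, and --forced_bos_token pins the first generated token. A rough sketch of the equivalent behaviour at generation time, assuming an mBART-50 checkpoint (names are illustrative):

    from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

    tokenizer.src_lang = "en_XX"  # what --lang sets on the tokenizer
    batch = tokenizer("A long article to summarize.", return_tensors="pt")

    # Same effect as the script setting model.config.forced_bos_token_id
    # from the --forced_bos_token argument:
    summary_ids = model.generate(**batch, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
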
""" - source_lang: str = field(default=None, metadata={"help": "Source language id for summarization."}) - target_lang: str = field(default=None, metadata={"help": "Target language id for summarization."}) + lang: str = field(default=None, metadata={"help": "Language id for summarization."}) dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} @@ -241,8 +240,8 @@ class DataTrainingArguments: def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") - elif self.source_lang is None or self.target_lang is None: - raise ValueError("Need to specify the source language and the target language.") + elif self.lang is None: + raise ValueError("Need to specify the language.") else: if self.train_file is not None: @@ -393,9 +392,9 @@ def main(): if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if isinstance(tokenizer, MBartTokenizer): - model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] else: - model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") @@ -434,13 +433,12 @@ def main(): return if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): - assert data_args.target_lang is not None and data_args.source_lang is not None, ( - f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and " - "--target_lang arguments." + assert data_args.lang is not None, ( + f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" ) - tokenizer.src_lang = data_args.source_lang - tokenizer.tgt_lang = data_args.target_lang + tokenizer.src_lang = data_args.lang + tokenizer.tgt_lang = data_args.lang # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. @@ -450,8 +448,8 @@ def main(): model.config.forced_bos_token_id = forced_bos_token_id # Get the language codes for input/target. - source_lang = data_args.source_lang.split("_")[0] - target_lang = data_args.target_lang.split("_")[0] + source_lang = data_args.lang.split("_")[0] + target_lang = data_args.lang.split("_")[0] # Get the column names for input/target. 
     dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
     if data_args.text_column is None:

From f06d925bcb0048569215d8a9c10eb6d0d2c826a5 Mon Sep 17 00:00:00 2001
From: banda-larga <49756048+banda-larga@users.noreply.github.com>
Date: Wed, 12 Jan 2022 22:27:06 +0100
Subject: [PATCH 4/4] make style, run_summarization.py reformatted

---
 .../pytorch/summarization/run_summarization.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index 6ef336b388d578..4e717d8815fd83 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -71,6 +71,7 @@
 # A list of all multilingual tokenizers which require a lang attribute.
 MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
 
+
 @dataclass
 class ModelArguments:
     """
@@ -389,7 +390,7 @@ def main():
     )
     model.resize_token_embeddings(len(tokenizer))
-
+
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
         else:
@@ -433,9 +434,9 @@ def main():
         return
 
     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
-        assert data_args.lang is not None, (
-            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires the --lang argument"
-        )
+        assert (
+            data_args.lang is not None
+        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires the --lang argument"
 
         tokenizer.src_lang = data_args.lang
         tokenizer.tgt_lang = data_args.lang
@@ -478,10 +479,10 @@ def main():
     def preprocess_function(examples):
         # remove pairs where at least one record is None
-
+
         inputs, targets = [], []
         for i in range(len(examples[text_column])):
             if examples[text_column][i] is not None and examples[summary_column][i] is not None:
                 inputs.append(examples[text_column][i])
                 targets.append(examples[summary_column][i])
-
+
         inputs = [prefix + inp for inp in inputs]
@@ -679,7 +680,7 @@ def compute_metrics(eval_preds):
         kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
     else:
         kwargs["dataset"] = data_args.dataset_name
-
+
     if data_args.lang is not None:
         kwargs["language"] = data_args.lang
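
Note (outside the series): one way to exercise the flags introduced by these four patches end to end; the model, dataset, and output path below are illustrative only, not taken from the patches:

    import subprocess

    subprocess.run(
        [
            "python", "examples/pytorch/summarization/run_summarization.py",
            "--model_name_or_path", "facebook/mbart-large-50",
            "--lang", "en_XX",                 # added by this series
            "--forced_bos_token", "en_XX",     # added by this series
            "--dataset_name", "xsum",
            "--output_dir", "/tmp/mbart50-xsum",
            "--do_train",
            "--do_eval",
            "--predict_with_generate",
        ],
        check=True,
    )
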