From 90850b38e888e330cefd0569223c7ee07be8d515 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 1 Nov 2021 20:10:50 -0400 Subject: [PATCH 01/11] add mt5 model --- summertime/model/__init__.py | 2 + summertime/model/single_doc/__init__.py | 1 + .../model/single_doc/multilingual/__init__.py | 1 + .../multilingual/base_multilingual_model.py | 12 ++- .../single_doc/multilingual/mbart_model.py | 5 +- .../single_doc/multilingual/mt5_model.py | 81 +++++++++++++++++++ 6 files changed, 96 insertions(+), 6 deletions(-) create mode 100644 summertime/model/single_doc/multilingual/mt5_model.py diff --git a/summertime/model/__init__.py b/summertime/model/__init__.py index ef3c3810..45b083ed 100644 --- a/summertime/model/__init__.py +++ b/summertime/model/__init__.py @@ -5,6 +5,7 @@ LongformerModel, PegasusModel, TextRankModel, + MT5Model, ) from .multi_doc import MultiDocJointModel, MultiDocSeparateModel from .dialogue import HMNetModel, FlattenDialogueModel @@ -14,6 +15,7 @@ SUPPORTED_SUMM_MODELS = [ BartModel, MBartModel, + MT5Model, LexRankModel, LongformerModel, PegasusModel, diff --git a/summertime/model/single_doc/__init__.py b/summertime/model/single_doc/__init__.py index 7b8103e4..91bb0a1d 100644 --- a/summertime/model/single_doc/__init__.py +++ b/summertime/model/single_doc/__init__.py @@ -5,3 +5,4 @@ from .textrank_model import TextRankModel from .multilingual import MBartModel +from .multilingual import MT5Model diff --git a/summertime/model/single_doc/multilingual/__init__.py b/summertime/model/single_doc/multilingual/__init__.py index e496f090..f5ae191f 100644 --- a/summertime/model/single_doc/multilingual/__init__.py +++ b/summertime/model/single_doc/multilingual/__init__.py @@ -1 +1,2 @@ from .mbart_model import MBartModel +from .mt5_model import MT5Model diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 5f9158bf..2f3b7867 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -48,11 +48,15 @@ def assert_summ_input_language(cls, corpus, query): label = label.replace("__label__", "") if label in cls.lang_tag_dict: - print(f"Language '{label}' detected.") + print(f"Supported language '{label}' detected.") return cls.lang_tag_dict[label] else: raise ValueError( - f"Unsupported language '{label}'' detected!\n\ - Try checking if another of our multilingual models \ - supports this language." + f"Unsupported language '{label}' detected! \ +Try checking if another of our multilingual models \ +supports this language." ) + + @classmethod + def get_supported_languages(cls): #TODO: implement a display of supported languages for all models? + return cls.lang_tag_dict.keys() \ No newline at end of file diff --git a/summertime/model/single_doc/multilingual/mbart_model.py b/summertime/model/single_doc/multilingual/mbart_model.py index 1cfe683e..450703ae 100644 --- a/summertime/model/single_doc/multilingual/mbart_model.py +++ b/summertime/model/single_doc/multilingual/mbart_model.py @@ -94,7 +94,7 @@ def summarize(self, corpus, queries=None): ).to(self.device) encoded_summaries = self.model.generate( **batch, - decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code], + forced_bos_token_id=self.tokenizer.lang_code_to_id[lang_code], length_penalty=1.0, num_beams=4, early_stopping=True, @@ -113,9 +113,10 @@ def show_capability(cls) -> None: "Introduced in 2020, a multilingual variant of BART (a large neural model) " "trained on web crawl data.\n" "Strengths: \n - Multilinguality: supports 50 different languages\n" + " - Higher max input length than mT5 (1024)" "Weaknesses: \n - High memory usage" "Initialization arguments: \n " "- `device = 'cpu'` specifies the device the model is stored on and uses for computation. " - "Use `device='gpu'` to run on an Nvidia GPU." + "Use `device='cuda'` to run on an Nvidia GPU." ) print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py new file mode 100644 index 00000000..6c00c52c --- /dev/null +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -0,0 +1,81 @@ +from transformers import MT5ForConditionalGeneration, MT5Tokenizer +from .base_multilingual_model import MultilingualSummModel + + + +class MT5Model(MultilingualSummModel): + """ + MT5 Model for Multilingual Summarization + """ + # static variables + model_name = 'mT5' + is_extractive = False + is_neural = True + is_multilingual = True + + lang_tag_dict = { + "am": "am", "ar": "ar", + "az": "az", "bn": "bn", + "my": "my", "zh-CN": "zh-CN", + "zh-TW": "zh-TW", "en": "en", + "fr": "fr", "gu": "gu", + "ha": "ha", "hi": "hi", + "ig": "ig", "id": "id", + "ja": "ja", "rn": "rn", + "ko": "ko", "ky": "ky", + "mr": "mr", "np": "np", + "om": "om", "ps": "ps", + "fa": "fa", "pt": "pt", # missing pidgin from XLSum--does not have ISO 639-1 code + "pa": "pa", "ru": "ru", + "gd": "gd", "sr": "sr", + "si": "si", "so": "so", + "es": "es", "sw": "sw", + "ta": "ta", "te": "te", + "th": "th", "ti": "ti", + "tr": "tr", "uk": "uk", + "ur": "ur", "uz": "uz", + "vi": "vi", "cy": "cy", + "yo": "yo" + } #TODO: add supported langs from mT5 that are not in the XLSum dataset (not finetuned on) + + def __init__(self, device="cpu"): + + super(MT5Model, self).__init__( + trained_domain="News", + max_input_length=512, + max_output_length=None, + ) + + self.device = device + + model_name = "csebuetnlp/mT5_multilingual_XLSum" + self.tokenizer = MT5Tokenizer.from_pretrained(model_name) + self.model = MT5ForConditionalGeneration.from_pretrained(model_name).to(device) + + def summarize(self, corpus, queries=None): + self.assert_summ_input_type(corpus, queries) + + lang_code = self.assert_summ_input_language(corpus, queries) + + with self.tokenizer.as_target_tokenizer(): + batch = self.tokenizer(corpus, truncation=True, padding="longest", max_length=self.max_input_length, return_tensors="pt").to(self.device) + + encoded_summaries = self.model.generate(**batch, num_beams=4, length_penalty=1.0, early_stopping=True) + + summaries = self.tokenizer.batch_decode(encoded_summaries, skip_special_tokens=True) + + return summaries + + @classmethod + def show_capability(cls) -> None: + basic_description = cls.generate_basic_description() + more_details = ( + "Introduced in ____, a massively multilingual variant of Google's T5, a large neural model. " + "Trained on web crawled data and fine-tuned on XLSum, a 45-language multilingual news dataset.\n" + "Strengths: \n - Massively multilingual: supports 101 different languages\n" + "Weaknesses: \n - High memory usage\n - Lower max input length (512)" + "Initialization arguments: \n " + "- `device = 'cpu'` specifies the device the model is stored on and uses for computation. " + "Use `device='cuda'` to run on an Nvidia GPU." + ) + print(f"{basic_description} \n {'#'*20} \n {more_details}") From e59493813ab18bb54c8ffcb1e3747c00c5f0f4aa Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 1 Nov 2021 21:15:31 -0400 Subject: [PATCH 02/11] reformatting --- .../multilingual/base_multilingual_model.py | 6 +- .../single_doc/multilingual/mt5_model.py | 93 ++++++++++++------- 2 files changed, 66 insertions(+), 33 deletions(-) diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 2f3b7867..3fbfa5a9 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -58,5 +58,7 @@ def assert_summ_input_language(cls, corpus, query): ) @classmethod - def get_supported_languages(cls): #TODO: implement a display of supported languages for all models? - return cls.lang_tag_dict.keys() \ No newline at end of file + def get_supported_languages( + cls, + ): # TODO: implement a display of supported languages for all models? + return cls.lang_tag_dict.keys() diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py index 6c00c52c..c59056e3 100644 --- a/summertime/model/single_doc/multilingual/mt5_model.py +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -2,44 +2,65 @@ from .base_multilingual_model import MultilingualSummModel - class MT5Model(MultilingualSummModel): """ MT5 Model for Multilingual Summarization """ + # static variables - model_name = 'mT5' + model_name = "mT5" is_extractive = False is_neural = True is_multilingual = True lang_tag_dict = { - "am": "am", "ar": "ar", - "az": "az", "bn": "bn", - "my": "my", "zh-CN": "zh-CN", - "zh-TW": "zh-TW", "en": "en", - "fr": "fr", "gu": "gu", - "ha": "ha", "hi": "hi", - "ig": "ig", "id": "id", - "ja": "ja", "rn": "rn", - "ko": "ko", "ky": "ky", - "mr": "mr", "np": "np", - "om": "om", "ps": "ps", - "fa": "fa", "pt": "pt", # missing pidgin from XLSum--does not have ISO 639-1 code - "pa": "pa", "ru": "ru", - "gd": "gd", "sr": "sr", - "si": "si", "so": "so", - "es": "es", "sw": "sw", - "ta": "ta", "te": "te", - "th": "th", "ti": "ti", - "tr": "tr", "uk": "uk", - "ur": "ur", "uz": "uz", - "vi": "vi", "cy": "cy", - "yo": "yo" - } #TODO: add supported langs from mT5 that are not in the XLSum dataset (not finetuned on) + "am": "am", + "ar": "ar", + "az": "az", + "bn": "bn", + "my": "my", + "zh-CN": "zh-CN", + "zh-TW": "zh-TW", + "en": "en", + "fr": "fr", + "gu": "gu", + "ha": "ha", + "hi": "hi", + "ig": "ig", + "id": "id", + "ja": "ja", + "rn": "rn", + "ko": "ko", + "ky": "ky", + "mr": "mr", + "np": "np", + "om": "om", + "ps": "ps", + "fa": "fa", + "pt": "pt", # missing pidgin from XLSum--does not have ISO 639-1 code + "pa": "pa", + "ru": "ru", + "gd": "gd", + "sr": "sr", + "si": "si", + "so": "so", + "es": "es", + "sw": "sw", + "ta": "ta", + "te": "te", + "th": "th", + "ti": "ti", + "tr": "tr", + "uk": "uk", + "ur": "ur", + "uz": "uz", + "vi": "vi", + "cy": "cy", + "yo": "yo", + } # TODO: add supported langs from mT5 that are not in the XLSum dataset (not finetuned on) def __init__(self, device="cpu"): - + super(MT5Model, self).__init__( trained_domain="News", max_input_length=512, @@ -55,14 +76,24 @@ def __init__(self, device="cpu"): def summarize(self, corpus, queries=None): self.assert_summ_input_type(corpus, queries) - lang_code = self.assert_summ_input_language(corpus, queries) + self.assert_summ_input_language(corpus, queries) with self.tokenizer.as_target_tokenizer(): - batch = self.tokenizer(corpus, truncation=True, padding="longest", max_length=self.max_input_length, return_tensors="pt").to(self.device) - - encoded_summaries = self.model.generate(**batch, num_beams=4, length_penalty=1.0, early_stopping=True) + batch = self.tokenizer( + corpus, + truncation=True, + padding="longest", + max_length=self.max_input_length, + return_tensors="pt", + ).to(self.device) - summaries = self.tokenizer.batch_decode(encoded_summaries, skip_special_tokens=True) + encoded_summaries = self.model.generate( + **batch, num_beams=4, length_penalty=1.0, early_stopping=True + ) + + summaries = self.tokenizer.batch_decode( + encoded_summaries, skip_special_tokens=True + ) return summaries From 1415080ab51b68e5f2c2c6c6311c7901ad08c7c8 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 8 Nov 2021 16:06:34 -0500 Subject: [PATCH 03/11] add rest of mt5 languages to dict --- .../single_doc/multilingual/mt5_model.py | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py index c59056e3..ad34728f 100644 --- a/summertime/model/single_doc/multilingual/mt5_model.py +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -56,8 +56,72 @@ class MT5Model(MultilingualSummModel): "uz": "uz", "vi": "vi", "cy": "cy", - "yo": "yo", - } # TODO: add supported langs from mT5 that are not in the XLSum dataset (not finetuned on) + "yo": "yo", # <- up to here: langs included in XLSum + "af": "af", + "sq": "sq", + "hy": "hy", + "eu": "eu", + "be": "be", + "bg": "bg", + "ca": "ca", + # cebuano has no ISO-639-1 code + "ceb": "ceb", + "ny": "ny", + "co": "co", + "cs": "cs", + "da": "da", + "nl": "nl", + "eo": "eo", + "et": "et", + "tl": "tl", # tagalog in place of filipino + "fi": "fi", + "gl": "gl", + "ka": "ka", + "de": "de", + "el": "el", + "ht": "ht", + "haw": "haw", # hawaiian 639-3 code (not in fasttext id) + "he": "he", + "hmn": "hmn", # hmong 639-3 code (not in fasttext id) + "hu": "hu", + "is": "is", + "ga": "ga", + "it": "it", + "jv": "jv", + "kn": "kn", + "kk": "kk", + "km": "km", + "ku": "ku", + "lo": "lo", + "la": "la", + "lv": "lv", + "lt": "lt", + "lb": "lb", + "mk": "mk", + "mg": "mg", + "ms": "ms", + "ml": "ml", + "mt": "mt", + "mi": "mi", + "mn": "mn", + "ne": "ne", + "no": "no", + "pl": "pl", + "ro": "ro", + "sm": "sm", + "sn": "sn", + "sd": "sd", + "sk": "sk", + "sl": "sl", + "st": "st", + "su": "su", + "sv": "sv", + "tg": "tg", + "fy": "fy", + "xh": "xh", + "yi": "yi", + "zu": "zu", + } def __init__(self, device="cpu"): From 7e6b81ca1f0748075c0fc471b9127cd7d3f12f80 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 8 Nov 2021 16:12:27 -0500 Subject: [PATCH 04/11] use download caching --- .../multilingual/base_multilingual_model.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 3fbfa5a9..8a036ff5 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -1,5 +1,5 @@ from summertime.model.single_doc.base_single_doc_model import SingleDocSummModel - +from summertime.util.download_utils import get_cached_file_path, download_with_progressbar import urllib.request import fasttext @@ -22,18 +22,13 @@ def __init__( @classmethod def assert_summ_input_language(cls, corpus, query): - # TODO: add fasttext language detection here - + url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz" - # currently using compressed fasttext model - urllib.request.urlretrieve(url, "lid.176.ftz") - # tqdm( - # urllib.request.urlretrieve(url, "lid.176.ftz"), - # desc="Downloading language detector", - # ) + + filepath = get_cached_file_path("fasttext", "lid.176.ftz", url) classifier = fasttext.load_model( - "./lid.176.ftz" + filepath ) # TODO: change download location, # do not redownload every time if not necessary From d751db8776588cee2e945b110b851202eebc7988 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 8 Nov 2021 16:13:54 -0500 Subject: [PATCH 05/11] reformatting --- .../multilingual/base_multilingual_model.py | 11 ++++++----- summertime/model/single_doc/multilingual/mt5_model.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 8a036ff5..40d3988b 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -1,5 +1,8 @@ from summertime.model.single_doc.base_single_doc_model import SingleDocSummModel -from summertime.util.download_utils import get_cached_file_path, download_with_progressbar +from summertime.util.download_utils import ( + get_cached_file_path, + download_with_progressbar, +) import urllib.request import fasttext @@ -22,14 +25,12 @@ def __init__( @classmethod def assert_summ_input_language(cls, corpus, query): - + url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz" filepath = get_cached_file_path("fasttext", "lid.176.ftz", url) - classifier = fasttext.load_model( - filepath - ) # TODO: change download location, + classifier = fasttext.load_model(filepath) # TODO: change download location, # do not redownload every time if not necessary if all([isinstance(ins, list) for ins in corpus]): diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py index ad34728f..95932a9b 100644 --- a/summertime/model/single_doc/multilingual/mt5_model.py +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -56,7 +56,7 @@ class MT5Model(MultilingualSummModel): "uz": "uz", "vi": "vi", "cy": "cy", - "yo": "yo", # <- up to here: langs included in XLSum + "yo": "yo", # <- up to here: langs included in XLSum "af": "af", "sq": "sq", "hy": "hy", @@ -73,16 +73,16 @@ class MT5Model(MultilingualSummModel): "nl": "nl", "eo": "eo", "et": "et", - "tl": "tl", # tagalog in place of filipino + "tl": "tl", # tagalog in place of filipino "fi": "fi", "gl": "gl", "ka": "ka", "de": "de", "el": "el", "ht": "ht", - "haw": "haw", # hawaiian 639-3 code (not in fasttext id) + "haw": "haw", # hawaiian 639-3 code (not in fasttext id) "he": "he", - "hmn": "hmn", # hmong 639-3 code (not in fasttext id) + "hmn": "hmn", # hmong 639-3 code (not in fasttext id) "hu": "hu", "is": "is", "ga": "ga", @@ -101,7 +101,7 @@ class MT5Model(MultilingualSummModel): "mg": "mg", "ms": "ms", "ml": "ml", - "mt": "mt", + "mt": "mt", "mi": "mi", "mn": "mn", "ne": "ne", From 897ca45d9afc1c71a2f5abd7d28bc64836d33c37 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 8 Nov 2021 16:41:29 -0500 Subject: [PATCH 06/11] start on readme edits --- README.md | 20 ++++++++++--------- summertime/model/base_model.py | 10 ++++++++++ .../model/single_doc/base_single_doc_model.py | 4 ++++ .../multilingual/base_multilingual_model.py | 18 +++++++++-------- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f177ba53..0b60fb4c 100644 --- a/README.md +++ b/README.md @@ -68,15 +68,17 @@ Also, please run our colab notebook for a more hands-on demo and more examples. ## Models ### Supported Models -SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summariztion tasks (e.g., JointModel for multi-doc summarzation, BM25 retrieval for query-based summarization). - -| Models | Single-doc | Multi-doc | Dialogue-based | Query-based | -| --------- | :------------------: | :------------------: | :------------------: | :------------------: | -| BartModel | :heavy_check_mark: | | | | -| BM25SummModel | | | | :heavy_check_mark: | -| HMNetModel | | | :heavy_check_mark: | | -| LexRankModel | :heavy_check_mark: | | | | -| LongformerModel | :heavy_check_mark: | | | | +SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summarization tasks (e.g., JointModel for multi-doc summarzation, BM25 retrieval for query-based summarization). + +| Models | Single-doc | Multi-doc | Dialogue-based | Query-based | Multilingual | +| --------- | :------------------: | :------------------: | :------------------: | :------------------: | :------------------: | +| BartModel | :heavy_check_mark: | | | | | +| BM25SummModel | | | | :heavy_check_mark: | | +| HMNetModel | | | :heavy_check_mark: | | | +| LexRankModel | :heavy_check_mark: | | | | | +| LongformerModel | :heavy_check_mark: | | | | | +| MBartModel | :heavy_check_mark: | | | | 50 languages (Arabic, Czech, German, English, Spanish, Estonian, Finnish, French, Gujarati, Hindi, Italian, Japanese, Kazakh, Korean, Lithuanian, Latvian, Burmese, Nepali, Dutch, Romanian, Russian, Sinhala, Turkish, Vietnamese, Chinese, Afrikaans, Azerbaijani, Bengali, Persian, Hebrew, Croatian, Indonesian, Georgian, Khmer, Macedonian, Malayalam, Mongolian, Marathi, Polish, Pashto, Portuguese, Swedish, Tamil, Telugu, Thai, Tagalog, Ukrainian, Urdu, Xhosa, Slovenian) | +| MT5Model | :heavy_check_mark: | | | | 101 languages (full list [here](https://github.com/google-research/multilingual-t5#readme)) | | MultiDocJointModel | | :heavy_check_mark: | | | | MultiDocSeparateModel | | :heavy_check_mark: | | | | PegasusModel | :heavy_check_mark: | | | | diff --git a/summertime/model/base_model.py b/summertime/model/base_model.py index 57af59ef..e5740dd8 100644 --- a/summertime/model/base_model.py +++ b/summertime/model/base_model.py @@ -92,6 +92,16 @@ def generate_basic_description(cls) -> str: ) return basic_description + + # TODO nick: implement this function eventually! + # @classmethod + # def show_supported_languages(cls) -> str: + # """ + # Returns a list of supported languages for summarization. + # """ + # raise NotImplementedError( + # "The base class for models shouldn't be instantiated!" + # ) class SummPipeline(SummModel): diff --git a/summertime/model/single_doc/base_single_doc_model.py b/summertime/model/single_doc/base_single_doc_model.py index f6dca9d5..8cd00a5c 100644 --- a/summertime/model/single_doc/base_single_doc_model.py +++ b/summertime/model/single_doc/base_single_doc_model.py @@ -51,3 +51,7 @@ def assert_summ_input_language(cls, corpus, query): print(warning) return "en" # ISO-639-1 code for English + + # @classmethod + # def show_supported_languages(cls) -> str: + # return "english" diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 40d3988b..4fa35c1f 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -1,12 +1,11 @@ from summertime.model.single_doc.base_single_doc_model import SingleDocSummModel from summertime.util.download_utils import ( get_cached_file_path, - download_with_progressbar, ) -import urllib.request import fasttext + class MultilingualSummModel(SingleDocSummModel): lang_tag_dict = None @@ -30,7 +29,9 @@ def assert_summ_input_language(cls, corpus, query): filepath = get_cached_file_path("fasttext", "lid.176.ftz", url) - classifier = fasttext.load_model(filepath) # TODO: change download location, + classifier = fasttext.load_model( + str(filepath) + ) # TODO: change download location, # do not redownload every time if not necessary if all([isinstance(ins, list) for ins in corpus]): @@ -53,8 +54,9 @@ def assert_summ_input_language(cls, corpus, query): supports this language." ) - @classmethod - def get_supported_languages( - cls, - ): # TODO: implement a display of supported languages for all models? - return cls.lang_tag_dict.keys() + # @classmethod + # def show_supported_languages( + # cls, + # ): + # langs = [iso639.to_name(lang) for lang in cls.lang_tag_dict.keys()] + # return " ".join(langs) From bf35ca3ac5825e99fb9832af21950bb96c20ce4d Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Mon, 8 Nov 2021 17:01:09 -0500 Subject: [PATCH 07/11] finish first draft of multilingual model documentation --- README.md | 28 ++++++++++++++++++++++++++-- summertime/model/base_model.py | 2 +- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0b60fb4c..0ae75737 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,8 @@ Also, please run our colab notebook for a more hands-on demo and more examples. ## Models ### Supported Models -SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summarization tasks (e.g., JointModel for multi-doc summarzation, BM25 retrieval for query-based summarization). +SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summarization tasks (e.g., JointModel for multi-doc summarzation, BM25 retrieval for query-based summarization). Several multilingual models are also supported (mT5 and mBART). + | Models | Single-doc | Multi-doc | Dialogue-based | Query-based | Multilingual | | --------- | :------------------: | :------------------: | :------------------: | :------------------: | :------------------: | @@ -230,7 +231,7 @@ print(corpus) ``` ### Loading a custom dataset -You can use load custom data using the `CustomDataset` class that puts the data in the SummerTime dataset Class +You can usecustom data using the `CustomDataset` class that loads the data in the SummerTime dataset Class ```python from dataset import CustomDataset @@ -294,6 +295,7 @@ train_set = itertools.islice(cnn_dataset.train_set, 5) corpus = [instance.source for instance in train_set] + # Example 1 - traditional non-neural model # LexRank model lexrank = model.LexRankModel(corpus) @@ -422,6 +424,28 @@ query_based_multi_doc_models = assemble_model_pipeline(QMsumDataset) # ] ``` +### Multilingual summarization +The `summarize()` method of multilingual models automatically checks for input document language. + +Single-doc multilingual models can be initialized and used in the same way as monolingual models. They return an error if a language not supported by the model is input. + +```python +mbart_model = st_model.MBartModel() +mt5_model = st_model.MT5Model() + +# load Spanish portion of MLSum dataset +mlsum = datasets.MlsumDataset(["es"]) + +corpus = itertools.islice(mlsum.train_set, 5) +corpus = [instance.source for instance in train_set] + +# mt5 model will automatically detect Spanish as the language and indicate that this is supported! +mt5_model.summarize() +``` + +Soon to come: a simple pipeline model to first translate input text to English and then use monolingual models! + + ## To contribute diff --git a/summertime/model/base_model.py b/summertime/model/base_model.py index e5740dd8..e69f01c6 100644 --- a/summertime/model/base_model.py +++ b/summertime/model/base_model.py @@ -92,7 +92,7 @@ def generate_basic_description(cls) -> str: ) return basic_description - + # TODO nick: implement this function eventually! # @classmethod # def show_supported_languages(cls) -> str: From 68a8635c343fda4aa79b4da89108a7557d70fc1e Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Tue, 9 Nov 2021 09:45:03 -0500 Subject: [PATCH 08/11] reformatting --- .../single_doc/multilingual/base_multilingual_model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/summertime/model/single_doc/multilingual/base_multilingual_model.py b/summertime/model/single_doc/multilingual/base_multilingual_model.py index 4fa35c1f..08cfab1b 100644 --- a/summertime/model/single_doc/multilingual/base_multilingual_model.py +++ b/summertime/model/single_doc/multilingual/base_multilingual_model.py @@ -5,7 +5,6 @@ import fasttext - class MultilingualSummModel(SingleDocSummModel): lang_tag_dict = None @@ -29,10 +28,8 @@ def assert_summ_input_language(cls, corpus, query): filepath = get_cached_file_path("fasttext", "lid.176.ftz", url) - classifier = fasttext.load_model( - str(filepath) - ) # TODO: change download location, - # do not redownload every time if not necessary + fasttext.FastText.eprint = lambda x: None + classifier = fasttext.load_model(str(filepath)) if all([isinstance(ins, list) for ins in corpus]): prediction = classifier.predict(corpus[0]) From 92ad57f2d2a896e09151863c6d51ae4517d5753d Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Thu, 11 Nov 2021 18:16:52 -0500 Subject: [PATCH 09/11] [skip-ci] fix readme typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fdf48544..61e8aada 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ print(corpus) ``` ### Loading a custom dataset -You can usecustom data using the `CustomDataset` class that loads the data in the SummerTime dataset Class +You can use custom data using the `CustomDataset` class that loads the data in the SummerTime dataset Class ```python from summertime.dataset import CustomDataset From ce859d688f7459cfbb62d3a63451faa36c96af42 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Thu, 11 Nov 2021 18:36:40 -0500 Subject: [PATCH 10/11] fix mex additional merge conflict --- summertime/model/single_doc/multilingual/mt5_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py index 95932a9b..352507cc 100644 --- a/summertime/model/single_doc/multilingual/mt5_model.py +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -140,8 +140,6 @@ def __init__(self, device="cpu"): def summarize(self, corpus, queries=None): self.assert_summ_input_type(corpus, queries) - self.assert_summ_input_language(corpus, queries) - with self.tokenizer.as_target_tokenizer(): batch = self.tokenizer( corpus, From 32a904272766df105b8042e5516d77dd795f6836 Mon Sep 17 00:00:00 2001 From: NickSchoelkopf Date: Sat, 13 Nov 2021 14:32:08 -0500 Subject: [PATCH 11/11] fix changes for merge --- .../model/single_doc/longformer_model.py | 2 +- .../single_doc/multilingual/mbart_model.py | 2 +- .../single_doc/multilingual/mt5_model.py | 218 +++++++++--------- summertime/model/single_doc/pegasus_model.py | 2 +- 4 files changed, 113 insertions(+), 111 deletions(-) diff --git a/summertime/model/single_doc/longformer_model.py b/summertime/model/single_doc/longformer_model.py index eb487bed..f3fdea91 100644 --- a/summertime/model/single_doc/longformer_model.py +++ b/summertime/model/single_doc/longformer_model.py @@ -55,6 +55,6 @@ def show_capability(cls) -> None: "Strengths:\n - Correctly handles longer (> 2000 tokens) corpus.\n\n" "Weaknesses:\n - Less accurate on contexts outside training domain.\n\n" "Initialization arguments:\n " - ' - device: use `device="gpu"` to load onto \n' + ' - device: use `device="cuda"` to load onto an NVIDIA GPU.\n' ) print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/summertime/model/single_doc/multilingual/mbart_model.py b/summertime/model/single_doc/multilingual/mbart_model.py index a623467b..6a5bb8ce 100644 --- a/summertime/model/single_doc/multilingual/mbart_model.py +++ b/summertime/model/single_doc/multilingual/mbart_model.py @@ -93,7 +93,7 @@ def summarize(self, corpus, queries=None): encoded_summaries = self.model.generate( **batch, - forced_bos_token_id=self.tokenizer.lang_code_to_id[lang_code], + decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code], length_penalty=1.0, num_beams=4, early_stopping=True, diff --git a/summertime/model/single_doc/multilingual/mt5_model.py b/summertime/model/single_doc/multilingual/mt5_model.py index 352507cc..3c12ea97 100644 --- a/summertime/model/single_doc/multilingual/mt5_model.py +++ b/summertime/model/single_doc/multilingual/mt5_model.py @@ -13,115 +13,117 @@ class MT5Model(MultilingualSummModel): is_neural = True is_multilingual = True - lang_tag_dict = { - "am": "am", - "ar": "ar", - "az": "az", - "bn": "bn", - "my": "my", - "zh-CN": "zh-CN", - "zh-TW": "zh-TW", - "en": "en", - "fr": "fr", - "gu": "gu", - "ha": "ha", - "hi": "hi", - "ig": "ig", - "id": "id", - "ja": "ja", - "rn": "rn", - "ko": "ko", - "ky": "ky", - "mr": "mr", - "np": "np", - "om": "om", - "ps": "ps", - "fa": "fa", - "pt": "pt", # missing pidgin from XLSum--does not have ISO 639-1 code - "pa": "pa", - "ru": "ru", - "gd": "gd", - "sr": "sr", - "si": "si", - "so": "so", - "es": "es", - "sw": "sw", - "ta": "ta", - "te": "te", - "th": "th", - "ti": "ti", - "tr": "tr", - "uk": "uk", - "ur": "ur", - "uz": "uz", - "vi": "vi", - "cy": "cy", - "yo": "yo", # <- up to here: langs included in XLSum - "af": "af", - "sq": "sq", - "hy": "hy", - "eu": "eu", - "be": "be", - "bg": "bg", - "ca": "ca", + supported_langs = [ + "am", + "ar", + "az", + "bn", + "my", + "zh-CN", + "zh-TW", + "en", + "fr", + "gu", + "ha", + "hi", + "ig", + "id", + "ja", + "rn", + "ko", + "ky", + "mr", + "np", + "om", + "ps", + "fa", + "pt", # missing pidgin from XLSum--does not have ISO 639-1 code + "pa", + "ru", + "gd", + "sr", + "si", + "so", + "es", + "sw", + "ta", + "te", + "th", + "ti", + "tr", + "uk", + "ur", + "uz", + "vi", + "cy", + "yo", # <- up to here: langs included in XLSum + "af", + "sq", + "hy", + "eu", + "be", + "bg", + "ca", # cebuano has no ISO-639-1 code - "ceb": "ceb", - "ny": "ny", - "co": "co", - "cs": "cs", - "da": "da", - "nl": "nl", - "eo": "eo", - "et": "et", - "tl": "tl", # tagalog in place of filipino - "fi": "fi", - "gl": "gl", - "ka": "ka", - "de": "de", - "el": "el", - "ht": "ht", - "haw": "haw", # hawaiian 639-3 code (not in fasttext id) - "he": "he", - "hmn": "hmn", # hmong 639-3 code (not in fasttext id) - "hu": "hu", - "is": "is", - "ga": "ga", - "it": "it", - "jv": "jv", - "kn": "kn", - "kk": "kk", - "km": "km", - "ku": "ku", - "lo": "lo", - "la": "la", - "lv": "lv", - "lt": "lt", - "lb": "lb", - "mk": "mk", - "mg": "mg", - "ms": "ms", - "ml": "ml", - "mt": "mt", - "mi": "mi", - "mn": "mn", - "ne": "ne", - "no": "no", - "pl": "pl", - "ro": "ro", - "sm": "sm", - "sn": "sn", - "sd": "sd", - "sk": "sk", - "sl": "sl", - "st": "st", - "su": "su", - "sv": "sv", - "tg": "tg", - "fy": "fy", - "xh": "xh", - "yi": "yi", - "zu": "zu", - } + "ceb", + "ny", + "co", + "cs", + "da", + "nl", + "eo", + "et", + "tl", # tagalog in place of filipino + "fi", + "gl", + "ka", + "de", + "el", + "ht", + "haw", # hawaiian 639-3 code (not in fasttext id) + "he", + "hmn", # hmong 639-3 code (not in fasttext id) + "hu", + "is", + "ga", + "it", + "jv", + "kn", + "kk", + "km", + "ku", + "lo", + "la", + "lv", + "lt", + "lb", + "mk", + "mg", + "ms", + "ml", + "mt", + "mi", + "mn", + "ne", + "no", + "pl", + "ro", + "sm", + "sn", + "sd", + "sk", + "sl", + "st", + "su", + "sv", + "tg", + "fy", + "xh", + "yi", + "zu", + ] + + lang_tag_dict = {lang: lang for lang in supported_langs} def __init__(self, device="cpu"): diff --git a/summertime/model/single_doc/pegasus_model.py b/summertime/model/single_doc/pegasus_model.py index 0686e665..e9fc3336 100644 --- a/summertime/model/single_doc/pegasus_model.py +++ b/summertime/model/single_doc/pegasus_model.py @@ -52,6 +52,6 @@ def show_capability(cls): "Weaknesses: \n - High memory usage \n " "Initialization arguments: \n " "- `device = 'cpu'` specifies the device the model is stored on and uses for computation. " - "Use `device='gpu'` to run on an Nvidia GPU." + "Use `device='cuda'` to run on an Nvidia GPU." ) print(f"{basic_description} \n {'#'*20} \n {more_details}")