Add mT5 #98

Merged: 14 commits, Nov 13, 2021
44 changes: 34 additions & 10 deletions README.md
@@ -75,15 +75,18 @@ Also, please run our colab notebook for a more hands-on demo and more examples.
## Models

### Supported Models
SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summariztion tasks (e.g., JointModel for multi-doc summarzation, BM25 retrieval for query-based summarization).

| Models | Single-doc | Multi-doc | Dialogue-based | Query-based |
| --------- | :------------------: | :------------------: | :------------------: | :------------------: |
| BartModel | :heavy_check_mark: | | | |
| BM25SummModel | | | | :heavy_check_mark: |
| HMNetModel | | | :heavy_check_mark: | |
| LexRankModel | :heavy_check_mark: | | | |
| LongformerModel | :heavy_check_mark: | | | |
SummerTime supports different models (e.g., TextRank, BART, Longformer) as well as model wrappers for more complex summarization tasks (e.g., JointModel for multi-doc summarization, BM25 retrieval for query-based summarization). Several multilingual models are also supported (mT5 and mBART).


| Models | Single-doc | Multi-doc | Dialogue-based | Query-based | Multilingual |
| --------- | :------------------: | :------------------: | :------------------: | :------------------: | :------------------: |
| BartModel | :heavy_check_mark: | | | | |
| BM25SummModel | | | | :heavy_check_mark: | |
| HMNetModel | | | :heavy_check_mark: | | |
| LexRankModel | :heavy_check_mark: | | | | |
| LongformerModel | :heavy_check_mark: | | | | |
| MBartModel | :heavy_check_mark: | | | | 50 languages (Arabic, Czech, German, English, Spanish, Estonian, Finnish, French, Gujarati, Hindi, Italian, Japanese, Kazakh, Korean, Lithuanian, Latvian, Burmese, Nepali, Dutch, Romanian, Russian, Sinhala, Turkish, Vietnamese, Chinese, Afrikaans, Azerbaijani, Bengali, Persian, Hebrew, Croatian, Indonesian, Georgian, Khmer, Macedonian, Malayalam, Mongolian, Marathi, Polish, Pashto, Portuguese, Swedish, Tamil, Telugu, Thai, Tagalog, Ukrainian, Urdu, Xhosa, Slovenian) |
| MT5Model | :heavy_check_mark: | | | | 101 languages (full list [here](https://github.com/google-research/multilingual-t5#readme)) |
| MultiDocJointModel | | :heavy_check_mark: | | | |
| MultiDocSeparateModel | | :heavy_check_mark: | | | |
| PegasusModel | :heavy_check_mark: | | | | |
@@ -235,7 +238,7 @@ print(corpus)
```

### Loading a custom dataset
You can use load custom data using the `CustomDataset` class that puts the data in the SummerTime dataset Class
You can use custom data using the `CustomDataset` class, which loads the data into the SummerTime dataset class.
```python
from summertime.dataset import CustomDataset

@@ -298,6 +301,7 @@ train_set = itertools.islice(cnn_dataset.train_set, 5)
corpus = [instance.source for instance in train_set]



# Example 1 - traditional non-neural model
# LexRank model
lexrank = model.LexRankModel(corpus)
@@ -325,7 +329,26 @@ longformer_summary = longformer.summarize(corpus)
print(longformer_summary)
```

### Multilingual summarization
The `summarize()` method of multilingual models automatically detects the language of the input documents.

Single-doc multilingual models can be initialized and used in the same way as monolingual models. They raise an error if the input documents are in a language the model does not support.

```python
mbart_model = st_model.MBartModel()
mt5_model = st_model.MT5Model()

# load Spanish portion of MLSum dataset
mlsum = datasets.MlsumDataset(["es"])

corpus = itertools.islice(mlsum.train_set, 5)
corpus = [instance.source for instance in corpus]

# the mT5 model automatically detects Spanish as the input language
# and reports that it is supported
mt5_summaries = mt5_model.summarize(corpus)
```
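
If the input is in a language the model does not support, `summarize()` raises a `ValueError` (see the check in `base_multilingual_model.py` below). A minimal sketch, assuming Yoruba text (not in mBART's 50-language list above) and that fasttext identifies it correctly:

```python
try:
    # hypothetical Yoruba input; Yoruba is not in mBART's supported list
    mbart_model.summarize(["Báwo ni o ṣe wà lónìí?"])
except ValueError as e:
    print(e)  # Unsupported language 'yo' detected! ...
```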

Soon to come: a simple pipeline model to first translate input text to English and then use monolingual models!
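
Until then, here is a rough sketch of what such a pipeline could look like, using a Hugging Face translation model rather than anything in SummerTime (the `Helsinki-NLP/opus-mt-es-en` checkpoint and all glue code are illustrative assumptions, not library code):

```python
from transformers import pipeline

# translate Spanish documents to English, then summarize with a monolingual model
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-es-en")
english_docs = [
    translator(doc[:512])[0]["translation_text"]  # crude char-level truncation
    for doc in corpus
]

bart = st_model.BartModel()
english_summaries = bart.summarize(english_docs)
print(english_summaries)
```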

## Evaluation
SummerTime supports different evaluation metrics, including BertScore, Bleu, Meteor, Rouge, and RougeWe.
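
For example, a metric can be applied to generated and reference summaries. A hedged sketch (the `summertime.evaluation` import path and the `evaluate()` signature are assumptions inferred from the metric names above, not confirmed API):

```python
from summertime.evaluation import Rouge  # assumed import path

generated = ["the cat sat on the mat"]
references = ["a cat was sitting on the mat"]

rouge = Rouge()
# assumed signature: evaluate(generated, references) -> dict of scores
print(rouge.evaluate(generated, references))
```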
@@ -426,6 +449,7 @@ query_based_multi_doc_models = assemble_model_pipeline(QMsumDataset)
# ]
```

### Visualizing performance of different models on your dataset
Given a SummerTime dataset, you may use the `pipelines.assemble_model_pipeline` function to retrieve a list of initialized SummerTime models that are compatible with the provided dataset.
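
A minimal sketch of that call (the import paths are assumptions; the `assemble_model_pipeline(QMsumDataset)` example earlier shows the same pattern):

```python
from summertime.pipelines import assemble_model_pipeline  # assumed path
from summertime.dataset import MlsumDataset  # assumed path

# returns initialized SummerTime models compatible with the dataset
compatible_models = assemble_model_pipeline(MlsumDataset)
print(compatible_models)
```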

2 changes: 2 additions & 0 deletions summertime/model/__init__.py
@@ -5,6 +5,7 @@
LongformerModel,
PegasusModel,
TextRankModel,
MT5Model,
)
from .multi_doc import MultiDocJointModel, MultiDocSeparateModel
from .dialogue import HMNetModel, FlattenDialogueModel
@@ -14,6 +15,7 @@
SUPPORTED_SUMM_MODELS = [
BartModel,
MBartModel,
MT5Model,
LexRankModel,
LongformerModel,
PegasusModel,
10 changes: 10 additions & 0 deletions summertime/model/base_model.py
@@ -86,6 +86,16 @@ def generate_basic_description(cls) -> str:

return basic_description

# TODO nick: implement this function eventually!
> **Collaborator:** should this be in the base_model.py or the multilingual_model?

> **Collaborator:** Oh, I see that you are adding the function of returning "english" for non-multilingual models. Okay, then this is good.

# @classmethod
# def show_supported_languages(cls) -> str:
# """
# Returns a list of supported languages for summarization.
# """
# raise NotImplementedError(
# "The base class for models shouldn't be instantiated!"
# )


class SummPipeline(SummModel):
"""
1 change: 1 addition & 0 deletions summertime/model/single_doc/__init__.py
@@ -5,3 +5,4 @@
from .textrank_model import TextRankModel

from .multilingual import MBartModel
from .multilingual import MT5Model
4 changes: 4 additions & 0 deletions summertime/model/single_doc/base_single_doc_model.py
@@ -51,3 +51,7 @@ def assert_summ_input_type(cls, corpus, query):
print(warning)

return "en" # ISO-639-1 code for English

# @classmethod
# def show_supported_languages(cls) -> str:
# return "english"
2 changes: 1 addition & 1 deletion summertime/model/single_doc/longformer_model.py
@@ -55,6 +55,6 @@ def show_capability(cls) -> None:
"Strengths:\n - Correctly handles longer (> 2000 tokens) corpus.\n\n"
"Weaknesses:\n - Less accurate on contexts outside training domain.\n\n"
"Initialization arguments:\n "
' - device: use `device="gpu"` to load onto \n'
' - device: use `device="cuda"` to load onto an NVIDIA GPU.\n'
)
print(f"{basic_description} \n {'#'*20} \n {more_details}")
1 change: 1 addition & 0 deletions summertime/model/single_doc/multilingual/__init__.py
@@ -1 +1,2 @@
from .mbart_model import MBartModel
from .mt5_model import MT5Model
summertime/model/single_doc/multilingual/base_multilingual_model.py
@@ -1,6 +1,7 @@
from summertime.model.single_doc.base_single_doc_model import SingleDocSummModel

import urllib.request
from summertime.util.download_utils import (
get_cached_file_path,
)
import fasttext
from typing import Dict, List, Tuple

@@ -30,10 +31,11 @@ def assert_summ_input_type(cls, corpus, query):
super().assert_summ_input_type(corpus, query)

url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"
# currently using compressed fasttext model from FB
urllib.request.urlretrieve(url, "lid.176.ftz")

classifier = fasttext.load_model("./lid.176.ftz")
filepath = get_cached_file_path("fasttext", "lid.176.ftz", url)

fasttext.FastText.eprint = lambda x: None
classifier = fasttext.load_model(str(filepath))

# fasttext returns a tuple of 2 lists:
# the first list contains a list of predicted language labels
@@ -54,11 +56,18 @@ def assert_summ_input_type(cls, corpus, query):

# check if language code is in the supported language dictionary
if label in cls.lang_tag_dict:
print(f"Language '{label}' detected.")
print(f"Supported language '{label}' detected.")
return cls.lang_tag_dict[label]
else:
raise ValueError(
f"Unsupported language '{label}'' detected!\n\
Try checking if another of our multilingual models \
supports this language."
f"Unsupported language '{label}' detected! \
Try checking if another of our multilingual models \
supports this language."
)

# @classmethod
# def show_supported_languages(
# cls,
# ):
# langs = [iso639.to_name(lang) for lang in cls.lang_tag_dict.keys()]
# return " ".join(langs)
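
For reference, a standalone sketch of the fasttext language-ID call this method wraps (same `lid.176.ftz` model as above; the example sentence and local path are illustrative):

```python
import fasttext

# compressed language-ID model from Facebook, downloaded/cached by the code above
classifier = fasttext.load_model("lid.176.ftz")

# predict() returns a tuple: a list of label lists and an array of probabilities
labels, probs = classifier.predict(["Hola, ¿cómo estás?"])
print(labels)  # [['__label__es']]
```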
3 changes: 2 additions & 1 deletion summertime/model/single_doc/multilingual/mbart_model.py
@@ -112,9 +112,10 @@ def show_capability(cls) -> None:
"Introduced in 2020, a multilingual variant of BART (a large neural model) "
"trained on web crawl data.\n"
"Strengths: \n - Multilinguality: supports 50 different languages\n"
" - Higher max input length than mT5 (1024)"
"Weaknesses: \n - High memory usage"
"Initialization arguments: \n "
"- `device = 'cpu'` specifies the device the model is stored on and uses for computation. "
"Use `device='gpu'` to run on an Nvidia GPU."
"Use `device='cuda'` to run on an Nvidia GPU."
> **Collaborator:** Oh, no. I think this typo is actually common across all our models... Good catch!
>
> But do you mind fixing the others as well?

> **Collaborator (Author):** will fix!

)
print(f"{basic_description} \n {'#'*20} \n {more_details}")
176 changes: 176 additions & 0 deletions summertime/model/single_doc/multilingual/mt5_model.py
@@ -0,0 +1,176 @@
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from .base_multilingual_model import MultilingualSummModel


class MT5Model(MultilingualSummModel):
"""
MT5 Model for Multilingual Summarization
"""

# static variables
model_name = "mT5"
is_extractive = False
is_neural = True
is_multilingual = True

supported_langs = [
"am",
"ar",
"az",
"bn",
"my",
"zh-CN",
"zh-TW",
"en",
"fr",
"gu",
"ha",
"hi",
"ig",
"id",
"ja",
"rn",
"ko",
"ky",
"mr",
"np",
"om",
"ps",
"fa",
"pt", # missing pidgin from XLSum--does not have ISO 639-1 code
"pa",
"ru",
"gd",
"sr",
"si",
"so",
"es",
"sw",
"ta",
"te",
"th",
"ti",
"tr",
"uk",
"ur",
"uz",
"vi",
"cy",
"yo", # <- up to here: langs included in XLSum
"af",
"sq",
"hy",
"eu",
"be",
"bg",
"ca",
# cebuano has no ISO-639-1 code
"ceb",
"ny",
"co",
"cs",
"da",
"nl",
"eo",
"et",
"tl", # tagalog in place of filipino
"fi",
"gl",
"ka",
"de",
"el",
"ht",
"haw", # hawaiian 639-3 code (not in fasttext id)
"he",
"hmn", # hmong 639-3 code (not in fasttext id)
"hu",
"is",
"ga",
"it",
"jv",
"kn",
"kk",
"km",
"ku",
"lo",
"la",
"lv",
"lt",
"lb",
"mk",
"mg",
"ms",
"ml",
"mt",
"mi",
"mn",
"ne",
"no",
"pl",
"ro",
"sm",
"sn",
"sd",
"sk",
"sl",
"st",
"su",
"sv",
"tg",
"fy",
"xh",
"yi",
"zu",
]

lang_tag_dict = {lang: lang for lang in supported_langs}

def __init__(self, device="cpu"):

super(MT5Model, self).__init__(
trained_domain="News",
max_input_length=512,
max_output_length=None,
)

self.device = device

model_name = "csebuetnlp/mT5_multilingual_XLSum"
self.tokenizer = MT5Tokenizer.from_pretrained(model_name)
self.model = MT5ForConditionalGeneration.from_pretrained(model_name).to(device)

def summarize(self, corpus, queries=None):
self.assert_summ_input_type(corpus, queries)

with self.tokenizer.as_target_tokenizer():
batch = self.tokenizer(
corpus,
truncation=True,
padding="longest",
max_length=self.max_input_length,
return_tensors="pt",
).to(self.device)

encoded_summaries = self.model.generate(
**batch, num_beams=4, length_penalty=1.0, early_stopping=True
)

summaries = self.tokenizer.batch_decode(
encoded_summaries, skip_special_tokens=True
)

return summaries

@classmethod
def show_capability(cls) -> None:
basic_description = cls.generate_basic_description()
more_details = (
"Introduced in ____, a massively multilingual variant of Google's T5, a large neural model. "
"Trained on web crawled data and fine-tuned on XLSum, a 45-language multilingual news dataset.\n"
"Strengths: \n - Massively multilingual: supports 101 different languages\n"
"Weaknesses: \n - High memory usage\n - Lower max input length (512)"
"Initialization arguments: \n "
"- `device = 'cpu'` specifies the device the model is stored on and uses for computation. "
"Use `device='cuda'` to run on an Nvidia GPU."
)
print(f"{basic_description} \n {'#'*20} \n {more_details}")
2 changes: 1 addition & 1 deletion summertime/model/single_doc/pegasus_model.py
@@ -52,6 +52,6 @@ def show_capability(cls):
"Weaknesses: \n - High memory usage \n "
"Initialization arguments: \n "
"- `device = 'cpu'` specifies the device the model is stored on and uses for computation. "
"Use `device='gpu'` to run on an Nvidia GPU."
"Use `device='cuda'` to run on an Nvidia GPU."
)
print(f"{basic_description} \n {'#'*20} \n {more_details}")