Add mT5 #98

Merged 14 commits on Nov 13, 2021

haileyschoelkopf
Collaborator

@haileyschoelkopf haileyschoelkopf commented Nov 2, 2021

Add the mT5 model (using a checkpoint fine-tuned on the XLSum dataset).

Ready to merge, but still todo:

  • possibly adding the rest of the 101 languages that mT5-base was trained on to supported languages, instead of just including the languages in XLSum as supported languages (~45 languages)

@haileyschoelkopf haileyschoelkopf changed the title from "Nick/mt5" to "Add mT5" on Nov 2, 2021
@niansong1996
Collaborator

Thanks a lot for the PR, Nick!

I haven't had time to review everything yet, but I will do so ASAP.

One thing I noticed is that we don't have the mBART model listed in the Readme.md tables of supported models. Can you add it together with mT5 in this PR? Thanks!

@haileyschoelkopf
Collaborator Author

Sure, I can add documentation for this and mBART in the PR!

README.md Outdated
@@ -235,7 +238,7 @@ print(corpus)
```

### Loading a custom dataset
-You can use load custom data using the `CustomDataset` class that puts the data in the SummerTime dataset Class
+You can usecustom data using the `CustomDataset` class that loads the data in the SummerTime dataset Class
Collaborator

Ooops, missed a space here

Collaborator

@niansong1996 niansong1996 left a comment

Looks good. One slight issue is that a lot of the changes seem to duplicate #96, presumably because I was slow to review that branch. Sorry about that.

Let's try to merge #96 first and pull from main for this branch.

@@ -22,6 +22,8 @@ def __init__(self, device="cpu"):
 def summarize(self, corpus, queries=None):
     self.assert_summ_input_type(corpus, queries)
+
+    self.assert_summ_input_language(corpus, queries)
Collaborator

I made a comment about this in #96. After the refactoring on that branch, let's merge it to main and pull main into this branch so it's fixed here automatically as well.

@niansong1996
Collaborator

Okay, now that #96 is merged, should we rebase this branch on main or pull from main?

@haileyschoelkopf
Collaborator Author

@niansong1996 This PR should be all set for review now!

@@ -86,6 +86,16 @@ def generate_basic_description(cls) -> str:

return basic_description

# TODO nick: implement this function eventually!
Collaborator

Should this be in base_model.py or in the multilingual model?

Collaborator

Oh, I see that you are adding the function of returning "english" for non-multilingual models. Okay, then this is good.

is_neural = True
is_multilingual = True

lang_tag_dict = {
Collaborator

Hmm, if all the keys and values are the same, why is it a dict and not a list? I understand that you want to maintain some consistency across the different multilingual models. If it's not an mBART-specific thing, then it may be better to store them in a list and initialize the dict from that list. That leaves less room for error.

Collaborator Author

Yes, I was using a dict just to stay consistent with mBART. I can change the initialization for this dictionary to be from a list though
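
A minimal sketch of the list-based initialization suggested here. The language names are a small illustrative subset, and `SUPPORTED_LANGUAGES` is a hypothetical name rather than the identifier used in this PR:

```python
# Hypothetical subset of supported languages; the real list is longer.
SUPPORTED_LANGUAGES = ["english", "spanish", "french", "hindi"]

# Build the identity mapping from the list so that each language name is
# typed only once and a key can never drift from its value.
lang_tag_dict = {lang: lang for lang in SUPPORTED_LANGUAGES}

print(lang_tag_dict["spanish"])
```

Keeping the dict means lookups look the same as for mBART, where the values are model-specific language codes, while the list remains the single place to edit.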

"Weaknesses: \n - High memory usage"
"Initialization arguments: \n "
"- `device = 'cpu'` specifies the device the model is stored on and uses for computation. "
-"Use `device='gpu'` to run on an Nvidia GPU."
+"Use `device='cuda'` to run on an Nvidia GPU."
Collaborator

Oh, no. I think this typo is actually common across all our models... Good catch!

But do you mind fixing the others as well?

Collaborator Author

will fix!

@@ -93,7 +93,7 @@ def summarize(self, corpus, queries=None):

 encoded_summaries = self.model.generate(
     **batch,
-    decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code],
+    forced_bos_token_id=self.tokenizer.lang_code_to_id[lang_code],
Collaborator

Hmm, what's the difference?

Collaborator Author

don't think there is a difference, will switch back to decoder_start_token_id because its name is more self explanatory imo
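
For context, the Hugging Face `generate` documentation describes the two arguments differently: `decoder_start_token_id` is the token the decoder input starts with, while `forced_bos_token_id` is the token forced as the first generated token after that start token. Whether this changes the output for a particular checkpoint depends on how it was trained. The toy loop below only illustrates where each id lands in the sequence; `toy_generate`, `dummy_model`, and the token ids are hypothetical stand-ins, not part of `transformers` or of this PR:

```python
from typing import Callable, List, Optional


def toy_generate(
    next_token: Callable[[List[int]], int],
    decoder_start_token_id: int,
    forced_bos_token_id: Optional[int] = None,
    max_new_tokens: int = 3,
) -> List[int]:
    """Greedy loop that mimics where the two ids are placed."""
    # The decoder input always begins with decoder_start_token_id.
    sequence = [decoder_start_token_id]
    for step in range(max_new_tokens):
        if step == 0 and forced_bos_token_id is not None:
            # Only the first generated token is overridden.
            token = forced_bos_token_id
        else:
            token = next_token(sequence)
        sequence.append(token)
    return sequence


# Hypothetical token ids, chosen only for illustration.
EOS_ID, LANG_ID = 2, 250005


def dummy_model(sequence: List[int]) -> int:
    """Stand-in for a real model: returns the last token id plus one."""
    return sequence[-1] + 1


# Language code used as the start token.
start_with_lang = toy_generate(dummy_model, decoder_start_token_id=LANG_ID)

# Language code forced as the first generated token after the start token.
forced_lang = toy_generate(
    dummy_model, decoder_start_token_id=EOS_ID, forced_bos_token_id=LANG_ID
)

print(start_with_lang)
print(forced_lang)
```

In the first call the language code sits at position 0; in the second it sits at position 1, preceded by the start token.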

Collaborator

@niansong1996 niansong1996 left a comment

Great work! Left a few comments.

A more general question: did you write model-specific tests (for mBART and mT5) in tests/model_test.py? You can find more templates and examples in that file.

@haileyschoelkopf
Collaborator Author

@niansong1996 should be ready for merge! I have written generic tests for multilingual models (using a Spanish language instance from MLSum) but have not written any specific tests as was done for HMNet. Will do that in another PR though!

@niansong1996
Collaborator

Awesome! Merging this PR now.

@niansong1996 niansong1996 merged commit db7b2ad into main Nov 13, 2021
@haileyschoelkopf haileyschoelkopf deleted the nick/mT5 branch November 14, 2021 20:43
2 participants