Add mT5 #98
Conversation
Thanks a lot for the PR, Nick! I haven't had the time to review everything yet, which I will do ASAP. One thing I noticed is that we don't have the mBART model listed in the
Sure, I can add documentation for this and mBART in the PR!
README.md
Outdated
@@ -235,7 +238,7 @@ print(corpus)
 ```

 ### Loading a custom dataset
-You can use load custom data using the `CustomDataset` class that puts the data in the SummerTime dataset Class
+You can usecustom data using the `CustomDataset` class that loads the data in the SummerTime dataset Class
Oops, missed a space here
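As an aside, the `CustomDataset` usage described in that README line might look roughly like this minimal sketch; the constructor arguments and split names below are assumptions for illustration, not necessarily SummerTime's real API:

```python
# Minimal sketch of a CustomDataset-style wrapper, assuming each split is an
# iterable of {"source", "summary"} dicts. Names are illustrative only.
class CustomDataset:
    def __init__(self, train_set=None, validation_set=None, test_set=None):
        self.train_set = list(train_set or [])
        self.validation_set = list(validation_set or [])
        self.test_set = list(test_set or [])

    def __len__(self):
        # total number of examples across all splits
        return len(self.train_set) + len(self.validation_set) + len(self.test_set)

data = CustomDataset(train_set=[{"source": "a long article ...", "summary": "a short summary"}])
print(len(data))  # → 1
```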
@@ -22,6 +22,8 @@ def __init__(self, device="cpu"):
     def summarize(self, corpus, queries=None):
         self.assert_summ_input_type(corpus, queries)
+
+        self.assert_summ_input_language(corpus, queries)
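The `assert_summ_input_language` check added here could work along these lines; the toy detector and the supported-language tags below are illustrative assumptions, not the actual SummerTime implementation:

```python
# Hedged sketch of an input-language assertion. A real implementation would
# use a proper language detector (e.g. langdetect or fastText); this toy
# version only distinguishes common Spanish function words from everything else.
def naive_detect_language(text: str) -> str:
    words = set(text.lower().split())
    return "es" if words & {"el", "la", "que", "es"} else "en"

def assert_summ_input_language(corpus, supported_langs=("en",)):
    for doc in corpus:
        lang = naive_detect_language(doc)
        if lang not in supported_langs:
            raise ValueError(f"Unsupported input language: {lang!r}")
    return True

assert_summ_input_language(["la casa es grande"], supported_langs=("es",))
```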
I have made a comment about this in #96. After the refactoring on that branch, let's merge it to main and pull main into this branch so it's fixed here automatically as well.
Okay, now that #96 is merged, should we rebase this branch on main or pull from main?
@niansong1996 This PR should be all set for review now!
@@ -86,6 +86,16 @@ def generate_basic_description(cls) -> str:

         return basic_description

+    # TODO nick: implement this function eventually!
Should this be in the base_model.py or the multilingual_model?
Oh, I see that you are adding the function of returning "english" for non-multilingual models. Okay, then this is good.
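That convention (non-multilingual models defaulting to English) can be sketched as follows; the class and method names here are assumptions for illustration, not the actual base_model.py code:

```python
# Sketch: the base class reports English only; multilingual subclasses
# override the supported-language list.
class BaseModel:
    is_multilingual = False

    @classmethod
    def get_supported_languages(cls):
        return ["english"]

class MultilingualModel(BaseModel):
    is_multilingual = True
    supported_langs = ["english", "spanish", "german"]

    @classmethod
    def get_supported_languages(cls):
        return cls.supported_langs

print(BaseModel.get_supported_languages())  # → ['english']
```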
    is_neural = True
    is_multilingual = True

    lang_tag_dict = {
Hmm, if all the keys and values are the same, why is it a Dict and not a List? I understand that you want to maintain some consistency across different multilingual models. If it's not a mBART-specific thing, then it's maybe better to store them in a list and initialize the dict with that list. Leaves smaller room for error this way.
Yes, I was using a dict just to stay consistent with mBART. I can change the initialization for this dictionary to be from a list, though.
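The change being agreed on here amounts to deriving the dict from a single list, e.g. (the tags below are illustrative mBART-style codes, not the PR's actual list):

```python
# When keys and values are identical, build the dict from one list so the
# two can never drift apart.
SUPPORTED_LANG_TAGS = ["en_XX", "es_XX", "de_DE"]  # illustrative tags

lang_tag_dict = {tag: tag for tag in SUPPORTED_LANG_TAGS}
print(lang_tag_dict["es_XX"])  # → es_XX
```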
"Weaknesses: \n - High memory usage" | ||
"Initialization arguments: \n " | ||
"- `device = 'cpu'` specifies the device the model is stored on and uses for computation. " | ||
"Use `device='gpu'` to run on an Nvidia GPU." | ||
"Use `device='cuda'` to run on an Nvidia GPU." |
Oh, no. I think this typo is actually common across all our models... Good catch!
But do you mind fixing the others as well?
will fix!
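For reference, PyTorch device strings are `"cpu"` and `"cuda"` (optionally indexed, e.g. `"cuda:0"`); `"gpu"` is not a valid device, which is what the fix above corrects. A small hypothetical helper (not part of SummerTime) could guard against the common mistake:

```python
# Hypothetical helper: map the common "gpu" typo onto the valid PyTorch
# device string, defaulting anything else to "cpu".
def normalize_device(device: str) -> str:
    if device in ("gpu", "cuda") or device.startswith("cuda:"):
        return "cuda" if device == "gpu" else device
    return "cpu"

print(normalize_device("gpu"))  # → cuda
```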
@@ -93,7 +93,7 @@ def summarize(self, corpus, queries=None):

         encoded_summaries = self.model.generate(
             **batch,
-            decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code],
+            forced_bos_token_id=self.tokenizer.lang_code_to_id[lang_code],
Hmm, what's the difference?
Don't think there is a difference; will switch back to decoder_start_token_id because its name is more self-explanatory, imo.
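To make the two keywords under discussion concrete without downloading weights, here is a stub sketch; the stub classes and token ids are invented, while `lang_code_to_id` mirrors the real mBART tokenizer attribute:

```python
# Stub sketch of the generate() call shapes; a real mBART model/tokenizer
# from transformers would be used in practice.
class StubTokenizer:
    lang_code_to_id = {"en_XX": 1001, "es_XX": 1002}  # invented ids

class StubModel:
    def generate(self, input_ids, decoder_start_token_id=None, forced_bos_token_id=None):
        # a real model starts decoding from decoder_start_token_id, or forces
        # forced_bos_token_id as the first *generated* token
        start = decoder_start_token_id if decoder_start_token_id is not None else forced_bos_token_id
        return [[start, 7, 8]]

tok, model = StubTokenizer(), StubModel()
out = model.generate([0], decoder_start_token_id=tok.lang_code_to_id["es_XX"])
print(out[0][0])  # → 1002
```

One hedged caveat: in Hugging Face the two arguments are not strictly interchangeable. `decoder_start_token_id` is the token the decoder begins from, while `forced_bos_token_id` forces the first token generated after it (the mBART-50 convention uses the latter for the target-language code), so it may be worth double-checking which one the checkpoint expects.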
Great work! Left a few comments.
A more general comment: did you write model-specific tests (for mBART, mT5) in tests/model_test.py? You can find more templates/examples in that file.
@niansong1996 should be ready for merge! I have written generic tests for multilingual models (using a Spanish language instance from MLSum) but have not written any specific tests as was done for HMNet. Will do that in another PR though!
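A generic multilingual test like the one described (a Spanish instance, as in MLSum) might be sketched like this; the echo summarizer stands in for a real model, and none of these names come from the actual tests/model_test.py:

```python
import unittest

class EchoSummarizer:
    # stand-in for a real multilingual model: returns the first sentence
    def summarize(self, corpus):
        return [doc.split(".")[0] + "." for doc in corpus]

class TestMultilingualModels(unittest.TestCase):
    def test_spanish_instance(self):
        corpus = ["El gato duerme. El perro ladra."]
        summaries = EchoSummarizer().summarize(corpus)
        self.assertEqual(len(summaries), len(corpus))
        self.assertTrue(all(isinstance(s, str) and s for s in summaries))

unittest.main(argv=["model_test"], exit=False)
```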
Awesome! Merging this PR now.
add mT5 model (using a checkpoint fine-tuned on the XLSum dataset).
Ready to merge, but still todo: