[trainer] --model_parallel hasn't been implemented for most models #9347
Conversation
src/transformers/trainer.py
Outdated
if self.args.model_parallel:
    # XXX: ideally this register should be maintained elsewhere so that the trainer could just do
    # if model.model_parallel_is_supported()
    mp_supported = ["gpt2", "t5"]
Maybe we can check like this for now:
if not hasattr(model, "model_parallel"):
    raise ValueError(f"{model.config.model_type} implementation currently doesn't support model parallelism, therefore the --model_parallel cl arg cannot be used")
I like that check.
Thanks for the PR! I agree that we should add an assert like this.
Also, I don't think keeping a list we'd manually have to extend is the best way to go here. Maybe just checking whether the model has the attribute model_parallel is good enough for now... Wdyt?
I see this PR as a quick band-aid since we released the new cl arg w/o checking that it always works. We will surely improve it as we generalize MP rather than leave it this way; this is definitely not how it'll remain in the long run.
LGTM as a hotfix.
Regarding model_parallel, if that's to be a method in PreTrainedModel and cannot be used to distinguish between models which are parallelizable and models which are not, I think we can go ahead and add a flag for parallelizable models, like model.parallelizable.
Having a way to distinguish between parallelizable models and non-parallelizable models sounds like a must as we continue adding parallelization support.
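For illustration, a minimal sketch of what such a flag plus a Trainer-side guard could look like; the attribute name is_parallelizable and the helper function below are assumptions for this sketch, not the final API:

class PreTrainedModelSketch:
    # Default: most models do not support model parallelism.
    is_parallelizable = False


class T5Sketch(PreTrainedModelSketch):
    # Models that implement MP flip the flag.
    is_parallelizable = True


def assert_model_parallel_supported(model):
    """Hypothetical Trainer-side guard for the --model_parallel cl arg."""
    if not getattr(model, "is_parallelizable", False):
        raise ValueError(
            f"{model.__class__.__name__} doesn't support model parallelism, "
            "so --model_parallel cannot be used with it"
        )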
LGTM with the suggestion from Patrick.
So should we merge this one as a hot-fix? An absolute yes to that. And also, what do you think about tests? Currently we hardcode a list of parallelizable models (transformers/tests/test_modeling_t5.py, line 491 in 086718a). Should it remain this way, or should we automatically derive them by iterating over the model classes (transformers/tests/test_modeling_t5.py, line 489 in 086718a) and deriving which are parallelizable? Less code to write in the future.
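A sketch of what that derivation might look like in a common test mixin; the class and attribute names below are illustrative assumptions, not the actual test code:

class ModelTesterMixinSketch:
    # Today the parallelizable classes are hardcoded in a separate tuple per model test file.
    all_model_classes = ()

    @property
    def all_parallelizable_model_classes(self):
        # Derive the parallelizable subset instead of maintaining a second list by hand.
        return tuple(
            model_class
            for model_class in self.all_model_classes
            if getattr(model_class, "is_parallelizable", False)
        )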
I'd rather merge the proper check as a hotfix and then worry about the tests in a follow-up PR (I think we should have a combination of a flag (like for pruning) and checking that the models have the attributes there).
It will no longer be hot, but yes, I will code that ;) Thank you for the feedback, @sgugger
I'm not sure what you mean here. An example would be helpful to understand what you propose.
The class
OK, I added it, if you prefer w/o
I'm fine with this design but it differs from what we were talking about, so we should check the others are fine with it too before merging.
src/transformers/modeling_utils.py
Outdated
@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

    - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
      derived classes of the same architecture adding modules on top of the base model.
    - **_is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
A private flag should not appear in the public documentation.
Ah, right! So put the doc on the property then? This seems to be different from the others, so how should I document it? Or perhaps just make it into a public member? What is the standard?
Why not just have this class attribute be public, with no property?
I thought it was cleaner since it should be read-only, but it's fine as a non-property. Changed.
I don't think we have a "way", so that's why I'm never sure when something should be a property or a public attribute.
Thank you for the feedback/ideas, @sgugger.
TBH, base_model_prefix is the same: it should be read-only in theory, but we have it as a simple class attribute, so let's stay simple for this new attribute too :-)
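For illustration only, the two designs being weighed here, in a generic sketch (neither is the actual transformers code):

class WithProperty:
    _is_parallelizable = True

    @property
    def is_parallelizable(self):
        # Read-only from the outside: instances can't accidentally overwrite it.
        return self._is_parallelizable


class WithClassAttribute:
    # Plain and consistent with base_model_prefix; technically writable, which is fine in practice.
    is_parallelizable = True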
Yes, of course. That's why it is no longer a hotfix, but it seems to be fine: only one user has filed an issue about using a non-working --model_parallel.
So since the only change I proposed is from
Yes, let's wait for him to review this tomorrow morning (he's on European time for the next month or so).
LGTM, this is very clean!
Apparently we unleashed --model_parallel in trainer w/o checking if the model supports MP (most don't). This PR adds such a check. As we are gradually starting to build MP support, a cleaner solution will be made in the future, but for now this is good enough to prevent misleading false expectations, as reported in #9336.

(Also for the future: I'm not sure whether it'd be better to check model.config.architectures, which would be more precise than checking model_type, since it's the architectures that may or may not support MP within the same model_type. But that's a different discussion; a rough sketch of that idea follows below.)

Fixes: #9336
@patrickvonplaten, @sgugger