
[trainer] --model_parallel hasn't been implemented for most models #9347

Merged: 6 commits into huggingface:master on Jan 5, 2021

Conversation

@stas00 (Contributor) commented Dec 29, 2020

Apparently we unleashed --model_parallel in the trainer w/o checking whether the model supports MP (most don't). This PR:

  • checks whether the model supports MP and asserts otherwise
  • fixes the cl arg help to note that the flag will only work if the model supports MP

As we are gradually starting to build MP support, a cleaner solution will come in the future, but for now this is good enough to prevent the misleading false expectations reported in #9336.

(Also for the future, I'm not sure whether it'd be better to check model.config.architectures, which would be more precise than checking model_type since it's the architectures that may or may not support MP within the same model_type - but that's a different discussion).

Fixes: #9336

@patrickvonplaten, @sgugger

if self.args.model_parallel:
    # XXX: ideally this register should be maintained elsewhere so that the trainer could just do
    # if model.model_parallel_is_supported()
    mp_supported = ["gpt2", "t5"]
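As a self-contained illustration of this registry-based guard, here is a runnable sketch; `check_model_parallel` and the `SimpleNamespace` stand-ins are invented for the example, since the real check lives inline in `Trainer` and receives an actual `PreTrainedModel`:

```python
from types import SimpleNamespace

# Registry of model types known to support model parallelism at the time.
MP_SUPPORTED = ["gpt2", "t5"]

def check_model_parallel(model, model_parallel):
    """Raise if --model_parallel is requested for an unsupported model type."""
    if model_parallel and model.config.model_type not in MP_SUPPORTED:
        raise ValueError(
            f"{model.config.model_type} implementation currently doesn't support "
            "model parallelism, therefore --model_parallel cannot be used"
        )

# Stand-ins for real models; only .config.model_type is consulted.
t5_like = SimpleNamespace(config=SimpleNamespace(model_type="t5"))
bert_like = SimpleNamespace(config=SimpleNamespace(model_type="bert"))

check_model_parallel(t5_like, model_parallel=True)  # ok: "t5" is registered
try:
    check_model_parallel(bert_like, model_parallel=True)
except ValueError as err:
    print(err)  # "bert" is not in the registry
```

Note the drawback discussed below: this registry has to be extended by hand for every new model that gains MP support.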
A Contributor commented:

Maybe we can check like this for now:

if not hasattr(model, "model_parallel"):
    raise ValueError(
        f"{model.config.model_type} implementation currently doesn't support "
        "model parallelism, therefore the --model_parallel cl arg cannot be used"
    )

A Collaborator replied:

I like that check.

@patrickvonplaten (Contributor) left a comment:

Thanks for the PR! I agree that we should add an assert like this.

I also don't think keeping a list we'd have to extend manually is the best way to go here. Maybe just checking whether the model has the model_parallel attribute is good enough for now... Wdyt?

@stas00 (Contributor, Author) commented Dec 30, 2020

@alexorona proposed to have the model_parallel method in PreTrainedModel (#9323 (comment)), which would then break this code, as the attribute would be present in all models.

I see this PR as a quick band-aid, since we released the new cl arg w/o checking that it always works. We will surely improve it as we generalize MP; this is definitely not how it'll remain in the long run.

@stas00 stas00 added the Model Parallel (Model Parallelism Implementations) label Jan 2, 2021
@LysandreJik (Member) left a comment:

LGTM as a hotfix.

Regarding model_parallel, if that's to be a method in PreTrainedModel and cannot be used to distinguish between models which are parallelizable and models which are not, I think we can go ahead and add a flag for parallelizable models, like model.parallelizable.

Having a way to distinguish between parallelizable models and non-parallelizable models sounds like a must as we continue adding parallelization support.
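A minimal sketch of the opt-in flag pattern being discussed here; the class names below are illustrative stand-ins, not the actual transformers hierarchy:

```python
class PreTrainedModelSketch:
    # Default: a model is assumed not to support model parallelism.
    is_parallelizable = False

class ParallelizableModelSketch(PreTrainedModelSketch):
    # Models that actually implement model parallelism opt in explicitly.
    is_parallelizable = True

def supports_mp(model):
    """True only for models that explicitly declare MP support."""
    return getattr(model, "is_parallelizable", False)

print(supports_mp(PreTrainedModelSketch()))      # False
print(supports_mp(ParallelizableModelSketch()))  # True
```

Because the flag lives on the class with a False default, the check works uniformly even once a model_parallel method exists on the base class.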

@sgugger (Collaborator) left a comment:

LGTM with the suggestion from Patrick.


@stas00 (Contributor, Author) commented Jan 4, 2021

So should we merge this one as a hotfix?

An absolute yes to a PreTrainedModel.parallelizable accessor - default False, then a True override for each specific model head that implements it. That's better than checking the architecture, which doesn't guarantee that all of its heads are parallelizable.

And what do you think about tests? Currently we hardcode a list of parallelizable models:

all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()

Should it remain this way, or should we automatically derive it by iterating over all_model_classes:

all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()

and checking which of them are parallelizable? Less code to write in the future.
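The automatic derivation could be sketched like this, assuming each model class carries an is_parallelizable marker attribute; the *Sketch classes are stand-ins invented for the example:

```python
class T5ModelSketch:
    is_parallelizable = True

class T5ForConditionalGenerationSketch:
    is_parallelizable = True

class T5EncoderSketch:
    # Example of a head that does not support MP.
    is_parallelizable = False

all_model_classes = (
    T5ModelSketch,
    T5ForConditionalGenerationSketch,
    T5EncoderSketch,
)

# Derive the parallelizable subset instead of hardcoding a second tuple.
all_parallelizable_model_classes = tuple(
    cls for cls in all_model_classes
    if getattr(cls, "is_parallelizable", False)
)

print([cls.__name__ for cls in all_parallelizable_model_classes])
# → ['T5ModelSketch', 'T5ForConditionalGenerationSketch']
```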

@sgugger (Collaborator) commented Jan 4, 2021

I'd rather merge the proper check as a hotfix and then worry about the tests in a follow-up PR (I think we should have a combination of a flag (like for pruning) and checking the models having the attributes there).

@stas00 (Contributor, Author) commented Jan 4, 2021

It will no longer be hot, but yes, I will code that ;) thank you for the feedback, @sgugger

> I think we should have a combination of a flag (like for pruning) and checking the models having the attributes there.

I'm not sure what you mean here. An example would be helpful to understand what you propose.

@sgugger (Collaborator) commented Jan 4, 2021

The ModelTesterMixin class has a few attributes that control which common tests to apply. I just realized while reading it that it already has the test_model_parallel flag, so this part is done already. All that is left is to infer the models to test from the presence of the right attribute :-)
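A simplified stand-in for that pattern (not the actual ModelTesterMixin code) might look like this, with the flag gating the tests and the marker attribute selecting the classes:

```python
class ModelTesterMixinSketch:
    # Flag controlling whether the model-parallel common tests run at all.
    test_model_parallel = False
    all_model_classes = ()

    def parallelizable_classes(self):
        """Infer which classes to test from the marker attribute."""
        if not self.test_model_parallel:
            return ()
        return tuple(
            cls for cls in self.all_model_classes
            if getattr(cls, "is_parallelizable", False)
        )

# Fake model classes invented for the example.
class FakeParallelizable:
    is_parallelizable = True

class FakePlain:
    is_parallelizable = False

class T5TesterSketch(ModelTesterMixinSketch):
    test_model_parallel = True
    all_model_classes = (FakeParallelizable, FakePlain)

print([c.__name__ for c in T5TesterSketch().parallelizable_classes()])
# → ['FakeParallelizable']
```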

@stas00 (Contributor, Author) commented Jan 4, 2021

OK, I added a model.is_parallelizable property - let me know if this looks good. If you'd prefer it without the is_ prefix, or not as a property, please let me know.

@sgugger (Collaborator) left a comment:

I'm fine with this design but it differs from what we were talking about, so we should check the others are fine with it too before merging.

@@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

     - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
       derived classes of the same architecture adding modules on top of the base model.
+    - **_is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
A Collaborator commented:

A private flag should not appear in the public documentation.

@stas00 (Contributor, Author) replied Jan 4, 2021:

Ah, right! So should the doc go on the property then?

This one seems different from the others - how should I document it?

Or perhaps just make it a public member? What is the standard?

A Collaborator replied:

Why not just have this class attribute be public, with no property?

The Contributor Author replied:

I thought it was cleaner as a property since it should be read-only, but it's fine as a non-property; changed.

I don't think we have an established convention, which is why I'm never sure when something should be a property or a public attribute.

Thank you for the feedback/ideas, @sgugger.

A Collaborator replied:

TBH base_model_prefix is the same, it should be read-only in theory but we have it as a simple class attribute, so let's stay simple for this new attribute too :-)
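For illustration, the two designs weighed in this thread differ as follows; both classes below are simplified sketches invented for the example, not the real PreTrainedModel:

```python
class WithProperty:
    _is_parallelizable = True

    @property
    def is_parallelizable(self):
        # Read-only: assignment on an instance raises AttributeError,
        # because the property defines no setter.
        return self._is_parallelizable

class WithAttribute:
    # Plain class attribute, like base_model_prefix: writable, but simpler.
    is_parallelizable = True

m = WithProperty()
print(m.is_parallelizable)  # True
try:
    m.is_parallelizable = False
except AttributeError:
    print("read-only")

n = WithAttribute()
n.is_parallelizable = False  # allowed; shadows the class attribute
print(n.is_parallelizable)   # False
```

The thread settles on the plain attribute for consistency with base_model_prefix.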

@stas00 (Contributor, Author) commented Jan 4, 2021

> I'm fine with this design but it differs from what we were talking about, so we should check the others are fine with it too before merging.

Yes, of course.

That's why it is no longer a hotfix, but it seems to be fine - only one user has filed an issue about using a non-working --model_parallel so far.

@stas00 (Contributor, Author) commented Jan 4, 2021

So since the only change I proposed is from parallelizable to is_parallelizable, do you still think we ought to re-validate with @LysandreJik?

@sgugger (Collaborator) commented Jan 4, 2021

Yes, let's wait for him to review this tomorrow morning (he's on European time for the next month or so).

@stas00 stas00 requested a review from LysandreJik January 4, 2021 21:34
@LysandreJik (Member) left a comment:

LGTM, this is very clean!

@LysandreJik LysandreJik merged commit 748006c into huggingface:master Jan 5, 2021
@stas00 stas00 deleted the mp-check branch January 5, 2021 18:11
Labels: Model Parallel (Model Parallelism Implementations)

Successfully merging this pull request may close these issues.

"RuntimeError: Input, output and indices must be on the current device" when trying to finetune MBart
4 participants