Connect the model to the training type plugin at the start of run #8536
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master    #8536     +/-   ##
=========================================
  Coverage      93%      93%
=========================================
  Files         167      169      +2
  Lines       14037    14072     +35
=========================================
+ Hits        13008    13043     +35
  Misses       1029     1029
LGTM!
""" | ||
self.setup_training_type_plugin(model) | ||
self.setup_training_type_plugin() |
small note: This could be an issue if we ever decide to expose the Accelerator API.
Why an issue? There's the `connect` hook to connect the model already. `self.model` should be available for the plugin when `setup` is called.
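To illustrate the point, here is a minimal sketch of that flow, using simplified stand-in classes rather than the actual Lightning implementation:

```python
# Minimal sketch with hypothetical, simplified classes; the real Lightning
# Accelerator and TrainingTypePlugin have many more responsibilities.
class TrainingTypePlugin:
    def __init__(self):
        self.model = None

    def connect(self, model):
        # the connect hook attaches the model to the plugin
        self.model = model

    def setup(self):
        # by the time setup runs, self.model is already attached,
        # so no model argument needs to be passed in
        assert self.model is not None


class Accelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def connect(self, model):
        self.training_type_plugin.connect(model)

    def setup_training_type_plugin(self):
        self.training_type_plugin.setup()


plugin = TrainingTypePlugin()
accelerator = Accelerator(plugin)
accelerator.connect(object())             # model attached here
accelerator.setup_training_type_plugin()  # setup() can rely on self.model
```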
@carmocca is this an invariant? should the accelerator assert that the model is available before calling setup training type plugin?
@ananthsub We could
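A hypothetical fragment of what such a check could look like inside the accelerator; this assertion is not something the PR adds, and the `trainer` parameter is only illustrative:

```python
# Hypothetical fragment: fail loudly if the model was never connected
# before the training type plugin is set up.
def setup(self, trainer):
    assert self.lightning_module is not None, (
        "connect(model) must be called before setup_training_type_plugin()"
    )
    self.setup_training_type_plugin()
```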
Pull request was converted to draft
def setup_environment(self) -> None:
    super().setup_environment()
    # read the flag from the LightningModule here, after the checkpoint
    # hooks have had a chance to run (see the explanation below)
    model_call_configure_sharded_model_hook = getattr(
        self.lightning_module, "call_configure_sharded_model_hook", False
    )
Documenting the explanation for this change: without it, the test `tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py::test_fully_sharded_plugin_checkpoint` fails.
That test calls `model.setup()` inside the `on_load_checkpoint` hook. Its `setup` implementation manually modifies the value of `self.call_configure_sharded_model_hook`:
def setup(self, stage: str) -> None:
    self.call_configure_sharded_model_hook = False
    ...

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    self.setup("fit")
This means that, before this change, the `call_configure_sharded_model_hook` value would diverge between the model and the training type plugin, because the training type plugin's check was done in the `connect` hook (see the diff).
Some pseudocode to illustrate the call order:

Order before this PR:
    load_checkpoint_weights
    model.on_load_checkpoint (in the test)
        model.setup -> this sets model.call_configure_sharded
    acc.connect -> this checks model.call_configure_sharded
    acc.setup_env
    model.setup
    configure_sharded_model -> this sets model.call_configure_sharded
    acc.setup

Order after this PR without this change:
    acc.connect -> this checks model.call_configure_sharded
    load_checkpoint_weights
    model.on_load_checkpoint (in the test)
        model.setup -> this sets model.call_configure_sharded **BUG**: we already checked it
    acc.setup_env
    model.setup
    configure_sharded_model -> this sets model.call_configure_sharded
    acc.setup

Order after this PR with this change:
    acc.connect
    load_checkpoint_weights
    model.on_load_checkpoint (in the test)
        model.setup -> this sets model.call_configure_sharded
    acc.setup_env -> this checks model.call_configure_sharded
    model.setup
    configure_sharded_model -> this sets model.call_configure_sharded
    acc.setup
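To make the timing issue concrete, here is a small, self-contained sketch with simplified stand-in classes (not the real Lightning code) showing why reading the flag in `connect` sees a stale value while reading it in `setup_environment` sees the updated one:

```python
# Simplified sketch: the flag is flipped inside on_load_checkpoint(), which
# runs between connect() and setup_environment().
class Model:
    call_configure_sharded_model_hook = True

    def on_load_checkpoint(self, checkpoint):
        # mirrors the test: setup("fit") is called here and flips the flag
        self.call_configure_sharded_model_hook = False


class Plugin:
    def __init__(self, model):
        self.model = model
        self.flag_seen_at_connect = None
        self.flag_seen_at_setup_env = None

    def connect(self):
        # old location of the check: runs before the checkpoint hooks
        self.flag_seen_at_connect = self.model.call_configure_sharded_model_hook

    def setup_environment(self):
        # new location of the check: runs after the checkpoint hooks
        self.flag_seen_at_setup_env = self.model.call_configure_sharded_model_hook


model = Model()
plugin = Plugin(model)
plugin.connect()
model.on_load_checkpoint({})  # the flag changes here
plugin.setup_environment()

print(plugin.flag_seen_at_connect)    # True  -> stale value (old behaviour)
print(plugin.flag_seen_at_setup_env)  # False -> current value (new behaviour)
```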
""" | ||
self.setup_training_type_plugin(model) | ||
self.setup_training_type_plugin() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@carmocca is this an invariant? should the accelerator assert that the model is available before calling setup training type plugin?
What does this PR do?
With this PR, the `trainer.lightning_module` reference will be set at the very beginning of `_run`, concretely in the `self.accelerator.connect(model)` call. We want this because we need it available during all hooks, even the earliest ones like `setup`. Also, if the trainer has been reused and the model has changed, the `trainer.lightning_module` reference would otherwise be stale.
Part of #8498
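A rough sketch of the resulting order of operations; the names mirror the PR but this is a heavily simplified, hypothetical version rather than the actual Lightning code:

```python
# Hypothetical, heavily simplified sketch of the flow described above.
class Accelerator:
    def __init__(self):
        self.model = None

    def connect(self, model):
        self.model = model

    def setup_environment(self):
        pass

    def setup(self):
        # no longer receives the model: it was attached in connect()
        assert self.model is not None


class Trainer:
    def __init__(self):
        self.accelerator = Accelerator()

    @property
    def lightning_module(self):
        # resolved through the accelerator, so a reused Trainer never
        # holds a stale reference to a previous model
        return self.accelerator.model

    def _run(self, model):
        # connect first, so trainer.lightning_module is valid for every
        # subsequent hook, including the earliest ones such as setup()
        self.accelerator.connect(model)
        self.accelerator.setup_environment()
        self.accelerator.setup()


trainer = Trainer()
trainer._run(object())
print(trainer.lightning_module is not None)  # True
```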
Does your PR introduce any breaking changes? If yes, please list them.
(BETA `TrainingTypePlugin` API): The `setup` hook no longer takes a `model` argument.
Before submitting
PR review