-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
5/n: Extract reference model call to plugins/accelerators (#4773)
* Encapsulate extracting reference model within the plugin to allow custom wrapper logic to live within the plugin/accelerators * Add missing new lines * Fix call to accelerator * Removed double blank * Use accelerator backend * Handle case where wrapper has not been initialized within the plugin * Added basic get model tests, add better typing * Change model name * Split GPU/DDP test * Add stronger typing, skip ddp test on windows * Fix import * Fix import in dp * Fixed PEP8 definition * Add ddp launcher for ddp testing * Modify accelerator reference model to property, change name to reflect func * Revert property as this is incorrect.= * Revert across accelerators * Modified name to get_model_from_plugin * Code review changes, fix issue with dp * Add verb to function getter Co-authored-by: chaton <thomas@grid.ai>
- Loading branch information
Showing
10 changed files
with
181 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import sys | ||
|
||
import pytest | ||
import torch | ||
|
||
from pytorch_lightning import Trainer | ||
from tests.backends.launcher import DDPLauncher | ||
from tests.base.boring_model import BoringModel | ||
|
||
|
||
class TrainerGetModel(BoringModel):
    """Boring model asserting that ``trainer.get_model`` resolves to this exact instance."""

    def _assert_trainer_returns_self(self):
        # The trainer must hand back the very object being fitted, not a wrapper.
        assert self == self.trainer.get_model()

    def on_fit_start(self):
        self._assert_trainer_returns_self()

    def on_fit_end(self):
        self._assert_trainer_returns_self()
|
||
|
||
def test_get_model(tmpdir):
    """
    Ensure :meth:`trainer.get_model` returns the fitted model on the default (CPU) backend.
    """
    model = TrainerGetModel()

    # Keep the run tiny: 2 train batches, 2 val batches, a single epoch.
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
    )
    trainer.fit(model)
|
||
|
||
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_get_model_ddp_cpu(tmpdir):
    """
    Ensure :meth:`trainer.get_model` returns the fitted model when running ddp on cpu.
    """
    model = TrainerGetModel()

    # Minimal run over two CPU processes; the model's hooks perform the check.
    trainer_options = dict(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        accelerator='ddp_cpu',
        num_processes=2,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_get_model_gpu(tmpdir):
    """
    Ensure :meth:`trainer.get_model` returns the fitted model when training on a single GPU.
    """
    model = TrainerGetModel()

    # Minimal single-GPU run; the model's hooks perform the check.
    trainer_options = dict(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        gpus=1,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@DDPLauncher.run("--accelerator [accelerator]",
                 max_epochs=["1"],
                 accelerator=["ddp", "ddp_spawn"])
def test_get_model_ddp_gpu(tmpdir, args=None):
    """
    Ensure :meth:`trainer.get_model` returns the fitted model when using GPU + ddp accelerators.

    The ``args`` namespace is injected by :class:`DDPLauncher` and carries the
    accelerator variant ('ddp' or 'ddp_spawn') for this launched run.
    """
    model = TrainerGetModel()

    # Minimal one-GPU run under the launcher-selected ddp variant.
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        gpus=1,
        accelerator=args.accelerator,
    )
    trainer.fit(model)
    # DDPLauncher expects a truthy exit value from the launched test body.
    return 1