Skip to content

Commit

Permalink
moves init apex from LM to apex connector (#3923)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Oct 7, 2020
1 parent c1559a1 commit d71ed27
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 49 deletions.
3 changes: 0 additions & 3 deletions docs/source/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ Training set-up

- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.train_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`
Expand Down
11 changes: 0 additions & 11 deletions docs/source/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1024,17 +1024,6 @@ Advanced hooks
^^^^^^^^^^^^^^
Use these hooks to modify advanced functionality

configure_apex
~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_apex
:noindex:

configure_ddp
~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_ddp
:noindex:

configure_sync_batchnorm
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
34 changes: 0 additions & 34 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,40 +985,6 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule

return model

def configure_apex(
    self,
    amp: object,
    model: "LightningModule",
    optimizers: List[Optimizer],
    amp_level: str,
) -> Tuple["LightningModule", List[Optimizer]]:
    r"""Initialize NVIDIA apex AMP for the model and its optimizers.

    Override this hook to set up AMP your own way; it must return the
    (possibly wrapped) model together with a list of optimizers.

    Args:
        amp: pointer to the apex ``amp`` library object.
        model: pointer to the current :class:`LightningModule`.
        optimizers: list of optimizers from :meth:`configure_optimizers`.
        amp_level: AMP mode chosen ('O1', 'O2', etc...)

    Return:
        Apex wrapped model and optimizers

    Examples:
        .. code-block:: python

            # Default implementation used by Trainer.
            def configure_apex(self, amp, model, optimizers, amp_level):
                model, optimizers = amp.initialize(
                    model, optimizers, opt_level=amp_level,
                )
                return model, optimizers
    """
    # Delegate wrapping entirely to apex; the default is a one-call passthrough.
    wrapped_model, wrapped_optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level
    )
    return wrapped_model, wrapped_optimizers

def configure_optimizers(
self,
):
Expand Down
38 changes: 37 additions & 1 deletion pytorch_lightning/plugins/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from torch.optim.optimizer import Optimizer

try:
from apex import amp
Expand All @@ -24,10 +26,44 @@ def __init__(self, trainer):
self.trainer = trainer

def connect(self, model, optimizers):
    """Wrap the model and optimizers with apex AMP, then refresh scheduler state.

    Returns the apex-wrapped model and optimizers.
    """
    # The actual amp.initialize call lives in the configure_apex hook.
    wrapped_model, wrapped_optimizers = self.configure_apex(
        amp, model, optimizers, self.trainer.amp_level
    )
    # LR schedulers may still reference the pre-wrap optimizers; rebind them.
    self.trainer.reinit_scheduler_properties(
        wrapped_optimizers, self.trainer.lr_schedulers
    )
    return wrapped_model, wrapped_optimizers

def training_step(self, fx, args):
    """Execute the training-step callable ``fx`` on ``args`` and return its output."""
    return fx(args)

def configure_apex(
    self,
    amp: object,
    model: "LightningModule",
    optimizers: List[Optimizer],
    amp_level: str,
) -> Tuple["LightningModule", List[Optimizer]]:
    r"""Initialize AMP via apex for the given model and optimizers.

    Override to init AMP your own way. Must return a model and a list of
    optimizers.

    Args:
        amp: pointer to amp library object.
        model: pointer to current :class:`LightningModule`.
        optimizers: list of optimizers passed in :meth:`configure_optimizers`.
        amp_level: AMP mode chosen ('O1', 'O2', etc...)

    Return:
        Apex wrapped model and optimizers

    Examples:
        .. code-block:: python

            # Default implementation used by Trainer.
            def configure_apex(self, amp, model, optimizers, amp_level):
                model, optimizers = amp.initialize(
                    model, optimizers, opt_level=amp_level,
                )
                return model, optimizers
    """
    # Default: hand the model/optimizer pair straight to apex.
    initialized = amp.initialize(model, optimizers, opt_level=amp_level)
    model, optimizers = initialized
    return model, optimizers

0 comments on commit d71ed27

Please sign in to comment.