-
Notifications
You must be signed in to change notification settings - Fork 935
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] pytorch lightning integration - training
- Loading branch information
sash
committed
Jan 30, 2021
1 parent
11b531e
commit debf307
Showing
18 changed files
with
791 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
from typing import Optional | ||
|
||
import pytorch_lightning as pl | ||
from mmf.datasets.multi_dataset_loader import MultiDatasetLoader | ||
from mmf.utils.general import get_batch_size | ||
|
||
|
||
class LightningDataModule(pl.LightningDataModule):
    """Lightning data module exposing MMF's multi-dataset loaders.

    Wraps one ``MultiDatasetLoader`` per split (train/val/test), loads each
    with the provided MMF config, and hands them to Lightning via the
    ``*_dataloader`` hooks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Batch size is resolved by MMF from the config / distributed setup.
        self.batch_size = get_batch_size()

        # One multi-dataset loader per split, constructed eagerly.
        self.train_loader = MultiDatasetLoader("train")
        self.val_loader = MultiDatasetLoader("val")
        self.test_loader = MultiDatasetLoader("test")

        # Load splits in the same order they were constructed.
        for split_loader in (self.train_loader, self.val_loader, self.test_loader):
            split_loader.load(self.config)

    def prepare_data(self):
        """No-op: datasets are prepared by MMF's loaders in ``__init__``."""

    def setup(self, stage: Optional[str] = None):
        """No-op: all splits are built eagerly in ``__init__``."""

    def train_dataloader(self):
        """Return the training multi-dataset loader."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation multi-dataset loader."""
        return self.val_loader

    def test_dataloader(self):
        """Return the test multi-dataset loader."""
        return self.test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
import logging | ||
|
||
from mmf.common.registry import registry | ||
from mmf.utils.checkpoint import Checkpoint | ||
from pytorch_lightning.callbacks.base import Callback | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LightningLoopCallback(Callback):
    """Hooks MMF-specific loop behavior into the Lightning trainer.

    Responsibilities: periodic checkpointing keyed off the global step,
    registering the current epoch with MMF's registry, and cycling the
    multi-dataset loaders after each train/validation batch.
    """

    def __init__(self, lightning_trainer):
        super().__init__()
        self.lightning_trainer = lightning_trainer

    def on_init_start(self, trainer):
        # Checkpoint machinery is built here, once Lightning begins
        # initializing, rather than in __init__.
        self._checkpoint = Checkpoint(self.lightning_trainer)
        training_config = self.lightning_trainer.config.training
        self._checkpoint_interval = training_config.checkpoint_interval

    def on_train_start(self, trainer, pl_module):
        # Expose the epoch to the rest of MMF via the global registry.
        registry.register("current_epoch", trainer.current_epoch)

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        # Save periodically, keyed off the global optimizer step. Note this
        # also fires at global_step == 0 (0 % interval == 0).
        if trainer.global_step % self._checkpoint_interval == 0:
            self._save_checkpoint(trainer)

        # Advance the multi-dataset rotation so the next batch is drawn
        # from the next dataloader.
        self.lightning_trainer.data_module.train_loader.change_dataloader()

    def on_train_end(self, trainer, pl_module):
        # Run a final validation pass when training finishes.
        trainer.run_evaluation(test_mode=False)

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        # TODO: sash Needs implementation - coming soon
        self.lightning_trainer.data_module.val_loader.change_dataloader()

    def _save_checkpoint(self, trainer):
        # TODO: sash Needs implementation - coming soon; currently a stub
        # that only logs. self._checkpoint is not yet used.
        logger.info("Checkpoint time. Saving a checkpoint.")
Oops, something went wrong.