[feat] pytorch lightning integration - training
sash committed Jan 30, 2021
1 parent 11b531e commit debf307
Showing 18 changed files with 791 additions and 55 deletions.
1 change: 1 addition & 0 deletions mmf/configs/defaults.yaml
@@ -4,6 +4,7 @@ config_version: 1.0
# Configuration for training
training:
# Name of the trainer class used to define the training/evaluation loop
# `mmf` for the default trainer, `lightning` for the PyTorch Lightning trainer
trainer: mmf
# Seed to be used for training. -1 means random seed between 1 and 100000.
# Either pass a fixed seed through your config or command line arguments
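
For example, a user config that opts into the Lightning trainer only needs to override this key (a minimal sketch mirroring the keys above):

training:
  trainer: lightning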
2 changes: 1 addition & 1 deletion mmf/datasets/base_dataset.py
@@ -65,7 +65,7 @@ def init_processors(self):

def prepare_batch(self, batch):
"""
Can be possibly overridden in your child class
Can be possibly overridden in your child class. Deprecated with PyTorch Lightning.
Prepare batch for passing to model. Whatever is returned from here will
be directly passed to model's forward function. Currently moves the batch to
37 changes: 37 additions & 0 deletions mmf/datasets/lightning_datamodule.py
@@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Optional

import pytorch_lightning as pl
from mmf.datasets.multi_dataset_loader import MultiDatasetLoader
from mmf.utils.general import get_batch_size


class LightningDataModule(pl.LightningDataModule):
def __init__(self, config):
super().__init__()
self.config = config
self.batch_size = get_batch_size()

self.train_loader = MultiDatasetLoader("train")
self.val_loader = MultiDatasetLoader("val")
self.test_loader = MultiDatasetLoader("test")

self.train_loader.load(self.config)
self.val_loader.load(self.config)
self.test_loader.load(self.config)

def prepare_data(self):
pass

def setup(self, stage: Optional[str] = None):
pass

def train_dataloader(self):
return self.train_loader

def val_dataloader(self):
return self.val_loader

def test_dataloader(self):
return self.test_loader
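
A minimal usage sketch for the new data module (assuming a loaded MMF configuration object `config`; building that config is not shown in this diff):

from mmf.datasets.lightning_datamodule import LightningDataModule


def build_train_loader(config):
    # The config is expected to describe the datasets loaded by MultiDatasetLoader.
    data_module = LightningDataModule(config)
    data_module.prepare_data()   # currently a no-op
    data_module.setup("fit")     # currently a no-op
    return data_module.train_dataloader()  # the "train" MultiDatasetLoader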
99 changes: 91 additions & 8 deletions mmf/models/base_model.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.

"""
Models built on top of Pythia need to inherit ``BaseModel`` class and adhere to
Models built on top of MMF need to inherit ``BaseModel`` class and adhere to
some format. To create a model for MMF, follow this quick cheatsheet.
1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your
@@ -48,19 +48,29 @@ def forward(self, sample_list):
from typing import Union

from mmf.common.registry import registry
from mmf.common.sample import to_device
from mmf.common.sample import SampleList, to_device
from mmf.modules.losses import Losses
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from omegaconf import MISSING, DictConfig, OmegaConf
from torch import nn


try:
import pytorch_lightning as pl
except ImportError:
print(
"BaseModel requires Pytorch Lightning. "
+ "Please follow the installation here: "
+ "https://pytorch-lightning.readthedocs.io/"
+ "en/latest/introduction_guide.html"
)
raise

logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
class BaseModel(pl.LightningModule):
"""For integration with MMF's trainer, datasets and other features,
models need to inherit this class, call `super`, write a build function,
write a forward function taking a ``SampleList`` as input and returning a
@@ -82,8 +92,10 @@ def __init__(self, config: Union[DictConfig, Config]):
config = OmegaConf.structured(config)

self.config = config

self._logged_warning = {"losses_present": False}
self._is_pretrained = False
self._is_pl_enabled = False

@classmethod
def from_params(cls, **kwargs):
@@ -93,10 +105,18 @@ def from_params(cls, **kwargs):
def is_pretrained(self):
return self._is_pretrained

@property
def is_pl_enabled(self):
return self._is_pl_enabled

@is_pretrained.setter
def is_pretrained(self, x: bool):
self._is_pretrained = x

@is_pl_enabled.setter
def is_pl_enabled(self, x: bool):
self._is_pl_enabled = x

def build(self):
"""Function to be implemented by the child class, in case they need to
build their model separately than ``__init__``. All model related
@@ -163,18 +183,81 @@ def forward(self, sample_list, *args, **kwargs):
"Forward of the child model class needs to be implemented."
)

def training_step(self, batch, batch_idx, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict: Dict containing loss.
"""
batch = self._ensure_sample_list(batch)
output = self(batch)
loss_dict = output["losses"]
output["loss"] = sum([loss.mean() for loss in loss_dict.values()])
return output

def validation_step(self, batch, batch_idx, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict
"""
batch = self._ensure_sample_list(batch)
output = self(batch)
# TODO: @sash Implementation coming soon! (next PR)
return output

def configure_optimizers(self):
""" Member function of PL modules. Used only when PL enabled."""
assert self._is_pl_enabled, (
"configure_optimizers should be only used as a member "
"function of LightningModule when pytorch lightning is enabled."
)

from mmf.utils.build import build_optimizer, build_scheduler

config = registry.get("config")
optimizer = build_optimizer(self, config)

if config.training.lr_scheduler:
lr_scheduler = build_scheduler(optimizer, config)
return {
"optimizer": optimizer,
"lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
}
else:
return optimizer

def _ensure_sample_list(self, batch):
if not isinstance(batch, SampleList):
# Try converting to SampleList
batch = SampleList(batch)
return batch

def __call__(self, sample_list, *args, **kwargs):
# Move to proper device i.e. same as the model before passing
model_device = next(self.parameters()).device
sample_list = to_device(sample_list, model_device)
if not self._is_pl_enabled:
# Move to proper device i.e. same as the model before passing
model_device = next(self.parameters()).device
sample_list = to_device(sample_list, model_device)

model_output = super().__call__(sample_list, *args, **kwargs)

# Don't do anything fancy to output if it is pretrained
if self.is_pretrained:
return model_output

# Make sure theat the output from the model is a Mapping
# Make sure that the output from the model is a Mapping
assert isinstance(
model_output, collections.abc.Mapping
), "A dict must be returned from the forward of the model."
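
To see how the new Lightning hooks interact with a concrete model, here is a minimal child-model sketch. The model name, config keys, and sample fields below are placeholders invented for illustration and are not part of this commit:

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from torch import nn
from torch.nn import functional as F


@registry.register_model("simple_classifier")  # hypothetical model name
class SimpleClassifier(BaseModel):
    def build(self):
        # "in_dim" and "num_labels" are illustrative config keys
        self.classifier = nn.Linear(self.config.in_dim, self.config.num_labels)

    def forward(self, sample_list, *args, **kwargs):
        # "feature" and "targets" are placeholder sample fields for this sketch
        scores = self.classifier(sample_list["feature"])
        loss = F.cross_entropy(scores, sample_list["targets"])
        # training_step above expects the output to carry a "losses" dict;
        # it sums the mean of each entry into the "loss" key Lightning optimizes
        return {"scores": scores, "losses": {"cross_entropy": loss}}

With is_pl_enabled set on such a model, training_step converts the batch to a SampleList, runs the forward, and aggregates the losses, while configure_optimizers builds the optimizer and an optional per-step scheduler from the registered config.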
5 changes: 2 additions & 3 deletions mmf/modules/losses.py
@@ -62,7 +62,7 @@ class Losses(nn.Module):
mostly doesn't need to use this class.
Attributes:
losses: List containing instanttions of each loss
losses: List containing instantiations of each loss
passed in config
"""

@@ -322,8 +322,7 @@ def forward(self, sample_list, model_output):

@registry.register_loss("nll_loss")
class NLLLoss(nn.Module):
"""Negative log likelikehood loss.
"""
"""Negative log likelikehood loss."""

def __init__(self):
super().__init__()
2 changes: 1 addition & 1 deletion mmf/trainers/callbacks/lr_scheduler.py
@@ -19,7 +19,7 @@ def __init__(self, config, trainer):

self._scheduler = None
if self.training_config.lr_scheduler is True:
self._scheduler = build_scheduler(self.trainer.optimizer, self.config)
self._scheduler = build_scheduler(trainer.optimizer, self.config)

def on_update_end(self, **kwargs):
if self._scheduler is not None:
39 changes: 9 additions & 30 deletions mmf/trainers/core/training_loop.py
@@ -2,16 +2,14 @@

import gc
import logging
import math
import warnings
from abc import ABC
from typing import Any, Dict

import torch
from mmf.common.registry import registry
from mmf.common.report import Report
from mmf.common.sample import to_device
from mmf.utils.general import clip_gradients
from mmf.utils.general import clip_gradients, get_max_updates
from torch import Tensor


@@ -218,32 +216,13 @@ def _extract_loss(self, report: Dict[str, Any]) -> Tensor:
return loss

def _calculate_max_updates(self):
max_updates = self.training_config.max_updates
max_epochs = self.training_config.max_epochs
if max_updates is None and max_epochs is None:
raise ValueError("Neither max_updates nor max_epochs is specified.")

if isinstance(
self.train_loader.current_dataset, torch.utils.data.IterableDataset
):
warnings.warn(
"max_epochs not supported for Iterable datasets. Falling back "
+ "to max_updates."
)
return max_updates

if max_updates is not None and max_epochs is not None:
warnings.warn(
"Both max_updates and max_epochs are specified. "
+ f"Favoring max_epochs: {max_epochs}"
)

if max_epochs is not None:
max_updates = (
math.ceil(
len(self.train_loader) / self.training_config.update_frequency
)
* max_epochs
)
config_max_updates = self.training_config.max_updates
config_max_epochs = self.training_config.max_epochs
max_updates, _ = get_max_updates(
config_max_updates,
config_max_epochs,
self.train_loader,
self.training_config.update_frequency,
)

return max_updates
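
For reference, a sketch of the logic the new helper presumably centralizes, mirroring the inline code removed above (the actual get_max_updates lives in mmf/utils/general.py and may differ; in particular, the second return value is an assumption here):

import math
import warnings

import torch


def get_max_updates(config_max_updates, config_max_epochs, train_loader, update_freq):
    if config_max_updates is None and config_max_epochs is None:
        raise ValueError("Neither max_updates nor max_epochs is specified.")

    if isinstance(train_loader.current_dataset, torch.utils.data.IterableDataset):
        warnings.warn(
            "max_epochs not supported for Iterable datasets. "
            "Falling back to max_updates."
        )
        return config_max_updates, config_max_epochs

    if config_max_updates is not None and config_max_epochs is not None:
        warnings.warn(
            "Both max_updates and max_epochs are specified. "
            f"Favoring max_epochs: {config_max_epochs}"
        )

    max_updates = config_max_updates
    if config_max_epochs is not None:
        max_updates = math.ceil(len(train_loader) / update_freq) * config_max_epochs
    return max_updates, config_max_epochs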
1 change: 1 addition & 0 deletions mmf/trainers/lightning_core/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates.
48 changes: 48 additions & 0 deletions mmf/trainers/lightning_core/loop_callback.py
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import logging

from mmf.common.registry import registry
from mmf.utils.checkpoint import Checkpoint
from pytorch_lightning.callbacks.base import Callback


logger = logging.getLogger(__name__)


class LightningLoopCallback(Callback):
def __init__(self, lightning_trainer):
super().__init__()
self.lightning_trainer = lightning_trainer

def on_init_start(self, trainer):
self._checkpoint = Checkpoint(self.lightning_trainer)
self._checkpoint_interval = (
self.lightning_trainer.config.training.checkpoint_interval
)

def on_train_start(self, trainer, pl_module):
registry.register("current_epoch", trainer.current_epoch)

def on_train_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if trainer.global_step % self._checkpoint_interval == 0:
self._save_checkpoint(trainer)

# prepare the next batch
self.lightning_trainer.data_module.train_loader.change_dataloader()

def on_train_end(self, trainer, pl_module):
trainer.run_evaluation(test_mode=False)

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
# TODO: sash Needs implementation - coming soon
self.lightning_trainer.data_module.val_loader.change_dataloader()

def _save_checkpoint(self, trainer):
logger.info("Checkpoint time. Saving a checkpoint.")
return
# TODO: sash Needs implementation - coming soon
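
A hedged sketch of how this callback could be attached to a PyTorch Lightning Trainer (the surrounding MMF Lightning trainer wiring is not part of this diff; the argument names below are assumptions):

import pytorch_lightning as pl

from mmf.trainers.lightning_core.loop_callback import LightningLoopCallback


def fit_with_loop_callback(lightning_trainer, model, data_module, max_updates):
    # `lightning_trainer` must expose `.config` and `.data_module`, which the
    # callback reads in on_init_start / on_train_batch_end above.
    callback = LightningLoopCallback(lightning_trainer)
    trainer = pl.Trainer(callbacks=[callback], max_steps=max_updates)
    trainer.fit(model, datamodule=data_module)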