[feat] PL mvp0: training #748

Closed · wants to merge 1 commit
28 changes: 28 additions & 0 deletions mmf/configs/defaults.yaml
@@ -4,6 +4,8 @@ config_version: 1.0
# Configuration for training
training:
# Name of the trainer class used to define the training/evaluation loop
# `mmf` for the default trainer, `lightning` for the PyTorch Lightning trainer
# the PyTorch Lightning trainer's params are at `trainer.params`
trainer: mmf
# Seed to be used for training. -1 means random seed between 1 and 100000.
# Either pass fixed through your config or command line arguments
@@ -131,6 +133,32 @@ training:
# drop in results.
fp16: false

trainer:
# Name of the trainer class used to define the training/evaluation loop
# `mmf` or `lightning` to specify the trainer to be used
# `mmf` for the MMF trainer;
# for MMF trainer params, please see the training params in the `training` config
# `lightning` for the PyTorch Lightning trainer;
# for Lightning trainer params, please see the Lightning docs for details, i.e.,
# https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-class-api
type: lightning
params:
gpus: null
num_nodes: 1
precision: 32
deterministic: false
benchmark: false
max_steps: 22000
max_epochs: null
gradient_clip_val: 0.0
num_sanity_val_steps: 0
automatic_optimization: true # only True is supported for now
checkpoint_callback: false
accumulate_grad_batches: 1
val_check_interval: 1000
log_every_n_steps: 100
limit_val_batches: 5

# Configuration for evaluation
evaluation:
# Metrics for evaluation
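
For readers trying out this PR, the new `trainer.params` block added to defaults.yaml above is presumably what ends up splatted into `pl.Trainer(**params)`. Here is a minimal sketch (not part of this change) of inspecting such a config with OmegaConf; the inline dict simply mirrors a subset of the defaults above rather than reading `mmf/configs/defaults.yaml`:

```python
# Minimal sketch, not part of this PR: inspecting the new `trainer` block with
# OmegaConf. The inline dict mirrors a subset of the defaults added above.
from omegaconf import OmegaConf

defaults = OmegaConf.create(
    {
        "training": {"trainer": "mmf"},  # existing selector, kept for back-compat
        "trainer": {
            "type": "lightning",
            "params": {
                "gpus": None,
                "max_steps": 22000,
                "val_check_interval": 1000,
                "log_every_n_steps": 100,
            },
        },
    }
)

# `trainer.params` is presumably what gets passed as pl.Trainer(**params).
print(OmegaConf.to_yaml(defaults.trainer.params))
```
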
3 changes: 2 additions & 1 deletion mmf/datasets/base_dataset.py
@@ -65,7 +65,8 @@ def init_processors(self):

def prepare_batch(self, batch):
"""
Can be possibly overridden in your child class
Can be possibly overridden in your child class. Not supported with the
Lightning trainer.

Prepare batch for passing to model. Whatever returned from here will
be directly passed to model's forward function. Currently moves the batch to
29 changes: 29 additions & 0 deletions mmf/datasets/lightning_datamodule.py
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import pytorch_lightning as pl
from mmf.datasets.multi_dataset_loader import MultiDatasetLoader
from mmf.utils.general import get_batch_size


class LightningDataModule(pl.LightningDataModule):
def __init__(self, config):
super().__init__()
self.config = config
self.batch_size = get_batch_size()

self.train_loader = MultiDatasetLoader("train")
self.val_loader = MultiDatasetLoader("val")
self.test_loader = MultiDatasetLoader("test")

self.train_loader.load(self.config)
self.val_loader.load(self.config)
self.test_loader.load(self.config)

def train_dataloader(self):
return self.train_loader

def val_dataloader(self):
return self.val_loader

def test_dataloader(self):
return self.test_loader
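
For context, a hedged sketch of how this data module could be driven by a stock Lightning trainer. `config` and `model` are assumed to be an already-built MMF `DictConfig` and `BaseModel`; the actual wiring lives in the Lightning trainer class elsewhere in this stack, so treat `fit_with_lightning` as illustrative only:

```python
# Illustrative sketch only, not the PR's actual wiring: feeding the new
# LightningDataModule to a stock pl.Trainer. `config` and `model` are assumed
# to be an already-built MMF DictConfig and BaseModel.
import pytorch_lightning as pl

from mmf.datasets.lightning_datamodule import LightningDataModule


def fit_with_lightning(config, model):
    data_module = LightningDataModule(config)
    # `trainer.params` comes from the defaults.yaml block added in this PR.
    trainer = pl.Trainer(**config.trainer.params)
    # The data module exposes train/val/test dataloaders, so fit() can pull
    # batches without going through the MMF training loop.
    trainer.fit(model, datamodule=data_module)
```
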
80 changes: 72 additions & 8 deletions mmf/models/base_model.py
@@ -1,8 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.

"""
Models built on top of Pythia need to inherit ``BaseModel`` class and adhere to
some format. To create a model for MMF, follow this quick cheatsheet.
Models built in MMF need to inherit ``BaseModel`` class and adhere to
a fixed format. To create a model for MMF, follow this quick cheatsheet.

1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your
class's ``__init__`` function.
@@ -47,21 +47,21 @@ def forward(self, sample_list):
from dataclasses import dataclass
from typing import List, Optional, Union

import pytorch_lightning as pl
from mmf.common.registry import registry
from mmf.common.sample import to_device
from mmf.common.sample import SampleList, to_device
from mmf.modules.losses import LossConfig, Losses
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_current_device
from omegaconf import MISSING, DictConfig, OmegaConf
from torch import nn


logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
class BaseModel(pl.LightningModule):
"""For integration with MMF's trainer, datasets and other features,
models needs to inherit this class, call `super`, write a build function,
write a forward function taking a ``SampleList`` as input and returning a
@@ -84,8 +84,10 @@ def __init__(self, config: Union[DictConfig, Config]):
config = OmegaConf.structured(config)

self.config = config

self._logged_warning = {"losses_present": False}
self._is_pretrained = False
self._is_pl_enabled = False

@classmethod
def from_params(cls, **kwargs):
@@ -95,10 +97,18 @@ def from_params(cls, **kwargs):
def is_pretrained(self):
return self._is_pretrained

@property
def is_pl_enabled(self):
return self._is_pl_enabled

@is_pretrained.setter
def is_pretrained(self, x: bool):
self._is_pretrained = x

@is_pl_enabled.setter
def is_pl_enabled(self, x: bool):
self._is_pl_enabled = x

def build(self):
"""Function to be implemented by the child class, in case they need to
build their model separately than ``__init__``. All model related
@@ -165,17 +175,71 @@ def forward(self, sample_list, *args, **kwargs):
"Forward of the child model class needs to be implemented."
)

def training_step(self, batch, batch_idx, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.

Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration

Returns:
Dict: Dict containing loss.
"""
batch = self._ensure_sample_list(batch)
output = self(batch)
loss_dict = output["losses"]
output["loss"] = sum(loss.mean() for loss in loss_dict.values())
return output

def validation_step(self, batch, batch_idx, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.

Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration

Returns:
Dict
"""
batch = self._ensure_sample_list(batch)
output = self(batch)
# TODO: @sash Implementation coming soon! (next PR)
return output

def configure_optimizers(self):
""" Member function of PL modules. Used only when PL enabled."""
assert self._is_pl_enabled, (
"configure_optimizers should be only used as a member "
"function of LightningModule when pytorch lightning is enabled."
)

from mmf.utils.build import build_lightning_optimizers

config = registry.get("config")
return build_lightning_optimizers(self, config)

def _ensure_sample_list(self, batch):
if not isinstance(batch, SampleList):
# Try converting to SampleList
batch = SampleList(batch)
return batch

def __call__(self, sample_list, *args, **kwargs):
# Move to proper device i.e. same as the model before passing
sample_list = to_device(sample_list, get_current_device())
if not self._is_pl_enabled:
# Move to proper device i.e. same as the model before passing
sample_list = to_device(sample_list, get_current_device())

model_output = super().__call__(sample_list, *args, **kwargs)

# Don't do anything fancy to output if it is pretrained
if self.is_pretrained:
return model_output

# Make sure theat the output from the model is a Mapping
# Make sure that the output from the model is a Mapping
assert isinstance(
model_output, collections.abc.Mapping
), "A dict must be returned from the forward of the model."
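
To make the new `training_step` contract concrete: the model's forward returns a dict with a `losses` mapping (as MMF models already do), and the hook reduces it to a single scalar under `loss` for Lightning's optimization. The toy model below is hypothetical and exists only to illustrate that contract:

```python
# Hypothetical toy model, only to illustrate the training_step contract above:
# forward returns {"losses": {...}} and the hook sums the mean of each loss
# into a single "loss" scalar that Lightning can backpropagate.
import torch
from torch import nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, batch):
        pred = self.linear(batch["features"])
        return {"losses": {"toy/mse": (pred - batch["targets"]) ** 2}}


def training_step_like(model, batch):
    output = model(batch)
    # Mirrors BaseModel.training_step: sum the mean of each named loss.
    output["loss"] = sum(loss.mean() for loss in output["losses"].values())
    return output


batch = {"features": torch.randn(8, 4), "targets": torch.randn(8, 1)}
out = training_step_like(ToyModel(), batch)
out["loss"].backward()  # a single scalar, ready for the optimizer step
```
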
2 changes: 1 addition & 1 deletion mmf/modules/losses.py
@@ -70,7 +70,7 @@ class Losses(nn.Module):
mostly doesn't need to use this class.

Attributes:
losses: List containing instanttions of each loss
losses: List containing instantiations of each loss
passed in config
"""

2 changes: 1 addition & 1 deletion mmf/trainers/callbacks/lr_scheduler.py
@@ -19,7 +19,7 @@ def __init__(self, config, trainer):

self._scheduler = None
if self.training_config.lr_scheduler is True:
self._scheduler = build_scheduler(self.trainer.optimizer, self.config)
self._scheduler = build_scheduler(trainer.optimizer, self.config)

def on_update_end(self, **kwargs):
if self._scheduler is not None:
39 changes: 9 additions & 30 deletions mmf/trainers/core/training_loop.py
@@ -2,16 +2,14 @@

import gc
import logging
import math
import warnings
from abc import ABC
from typing import Any, Dict

import torch
from mmf.common.registry import registry
from mmf.common.report import Report
from mmf.common.sample import to_device
from mmf.utils.general import clip_gradients
from mmf.utils.general import clip_gradients, get_max_updates
from torch import Tensor


@@ -218,32 +216,13 @@ def _extract_loss(self, report: Dict[str, Any]) -> Tensor:
return loss

def _calculate_max_updates(self):
max_updates = self.training_config.max_updates
max_epochs = self.training_config.max_epochs
if max_updates is None and max_epochs is None:
raise ValueError("Neither max_updates nor max_epochs is specified.")

if isinstance(
self.train_loader.current_dataset, torch.utils.data.IterableDataset
):
warnings.warn(
"max_epochs not supported for Iterable datasets. Falling back "
+ "to max_updates."
)
return max_updates

if max_updates is not None and max_epochs is not None:
warnings.warn(
"Both max_updates and max_epochs are specified. "
+ f"Favoring max_epochs: {max_epochs}"
)

if max_epochs is not None:
max_updates = (
math.ceil(
len(self.train_loader) / self.training_config.update_frequency
)
* max_epochs
)
config_max_updates = self.training_config.max_updates
config_max_epochs = self.training_config.max_epochs
max_updates, _ = get_max_updates(
config_max_updates,
config_max_epochs,
self.train_loader,
self.training_config.update_frequency,
)

return max_updates
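
The max-updates/max-epochs reconciliation that used to live inline here now comes from `mmf.utils.general.get_max_updates`. Below is a rough reconstruction of that logic, based on the lines removed in this diff; the real helper may differ in details such as its second return value:

```python
# Rough reconstruction (an assumption, based on the removed lines above) of
# what the shared get_max_updates helper encapsulates; the real implementation
# in mmf.utils.general may differ, e.g. in its second return value.
import math
import warnings

import torch


def get_max_updates_sketch(config_max_updates, config_max_epochs, train_loader, update_freq):
    if config_max_updates is None and config_max_epochs is None:
        raise ValueError("Neither max_updates nor max_epochs is specified.")

    if isinstance(train_loader.current_dataset, torch.utils.data.IterableDataset):
        warnings.warn(
            "max_epochs not supported for Iterable datasets. Falling back to max_updates."
        )
        return config_max_updates, config_max_epochs

    if config_max_updates is not None and config_max_epochs is not None:
        warnings.warn(
            f"Both max_updates and max_epochs are specified. Favoring max_epochs: {config_max_epochs}"
        )

    max_updates = config_max_updates
    if config_max_epochs is not None:
        # One update per `update_freq` batches, repeated for each epoch.
        max_updates = math.ceil(len(train_loader) / update_freq) * config_max_epochs
    return max_updates, config_max_epochs
```
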
1 change: 1 addition & 0 deletions mmf/trainers/lightning_core/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates.
49 changes: 49 additions & 0 deletions mmf/trainers/lightning_core/loop_callback.py
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
from typing import Any, List

from mmf.common.registry import registry
from mmf.common.sample import SampleList
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.base import Callback


logger = logging.getLogger(__name__)


class LightningLoopCallback(Callback):
def __init__(self, lightning_trainer: Any):
super().__init__()
self.lightning_trainer = lightning_trainer

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
registry.register("current_epoch", trainer.current_epoch)

def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: List,
batch: SampleList,
batch_idx: int,
dataloader_idx: int,
):
# prepare the next batch
self.lightning_trainer.data_module.train_loader.change_dataloader()

def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
# TODO: @sash next PR
pass

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: List,
batch: SampleList,
batch_idx: int,
dataloader_idx: int,
):
# prepare the next batch
self.lightning_trainer.data_module.val_loader.change_dataloader()
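
For context, a speculative sketch of how this callback might be attached to a `pl.Trainer`. The `lightning_trainer_wrapper` argument (an object exposing a `data_module` attribute) is assumed from the callback's usage above; the actual hookup happens in the MMF Lightning trainer class, not here:

```python
# Speculative wiring sketch, not part of this PR. The wrapper object is assumed
# to expose `data_module`, which is what LightningLoopCallback reads to rotate
# MultiDatasetLoader to the next dataset after every train/val batch.
import pytorch_lightning as pl

from mmf.trainers.lightning_core.loop_callback import LightningLoopCallback


def build_pl_trainer(lightning_trainer_wrapper, trainer_params):
    callback = LightningLoopCallback(lightning_trainer_wrapper)
    return pl.Trainer(callbacks=[callback], **trainer_params)
```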