From b94107e4167e58c41fc48e6fb7de12b60f9ee186 Mon Sep 17 00:00:00 2001 From: sash Date: Thu, 14 Jan 2021 21:04:31 -0800 Subject: [PATCH] [feat] pytorch lightning integration - training --- mmf/configs/defaults.yaml | 28 +++++ mmf/datasets/base_dataset.py | 3 +- mmf/datasets/lightning_datamodule.py | 29 +++++ mmf/models/base_model.py | 80 +++++++++++-- mmf/modules/losses.py | 2 +- mmf/trainers/callbacks/lr_scheduler.py | 2 +- mmf/trainers/core/training_loop.py | 39 ++---- mmf/trainers/lightning_core/__init__.py | 1 + mmf/trainers/lightning_core/loop_callback.py | 49 ++++++++ mmf/trainers/lightning_trainer.py | 111 ++++++++++++++++++ mmf/utils/build.py | 13 ++ mmf/utils/general.py | 29 +++++ pyproject.toml | 2 +- requirements.txt | 1 + tests/test_utils.py | 31 ++++- tests/trainers/lightning/__init__.py | 1 + .../lightning/lightning_trainer_mock.py | 52 ++++++++ .../lightning/test_grad_accumulate.py | 24 ++++ .../trainers/lightning/test_grad_clipping.py | 62 ++++++++++ .../lightning/test_loop_conditions.py | 35 ++++++ tests/trainers/lightning/test_loss.py | 38 ++++++ tests/trainers/lightning/test_lr_schedule.py | 39 ++++++ tests/trainers/lightning/test_utils.py | 111 ++++++++++++++++++ tests/trainers/test_fp16.py | 3 +- tests/trainers/test_training_loop.py | 31 ++++- 25 files changed, 763 insertions(+), 53 deletions(-) create mode 100644 mmf/datasets/lightning_datamodule.py create mode 100644 mmf/trainers/lightning_core/__init__.py create mode 100644 mmf/trainers/lightning_core/loop_callback.py create mode 100644 mmf/trainers/lightning_trainer.py create mode 100644 tests/trainers/lightning/__init__.py create mode 100644 tests/trainers/lightning/lightning_trainer_mock.py create mode 100644 tests/trainers/lightning/test_grad_accumulate.py create mode 100644 tests/trainers/lightning/test_grad_clipping.py create mode 100644 tests/trainers/lightning/test_loop_conditions.py create mode 100644 tests/trainers/lightning/test_loss.py create mode 100644 tests/trainers/lightning/test_lr_schedule.py create mode 100644 tests/trainers/lightning/test_utils.py diff --git a/mmf/configs/defaults.yaml b/mmf/configs/defaults.yaml index c5f0aee9e..f1d55dbba 100644 --- a/mmf/configs/defaults.yaml +++ b/mmf/configs/defaults.yaml @@ -4,6 +4,8 @@ config_version: 1.0 # Configuration for training training: # Name of the trainer class used to define the training/evalution loop + # `mmf` for default trainer, `lightning` for pytorch lightning trainer + # pytorch lightning trainer's params is at `trainer.params` trainer: mmf # Seed to be used for training. -1 means random seed between 1 and 100000. # Either pass fixed through your config or command line arguments @@ -131,6 +133,32 @@ training: # drop in results. 
fp16: false +trainer: + # Name of the trainer class used to define the training/evaluation loop + # `mmf` or `lightning` to specify the trainer to be used + # `mmf` for mmf trainer, + # for mmf trainer params, please see the training params in the `training` config + # `lightning` for PyTorch Lightning trainer + # for lightning trainer params, please see the lightning docs for details, i.e., + # https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-class-api + type: lightning + params: + gpus: null + num_nodes: 1 + precision: 32 + deterministic: false + benchmark: false + max_steps: 22000 + max_epochs: null + gradient_clip_val: 0.0 + num_sanity_val_steps: 0 + automatic_optimization: true # only True is supported for now + checkpoint_callback: false + accumulate_grad_batches: 1 + val_check_interval: 1000 + log_every_n_steps: 100 + limit_val_batches: 5 + # Configuration for evaluation evaluation: # Metrics for evaluation diff --git a/mmf/datasets/base_dataset.py b/mmf/datasets/base_dataset.py index 56677beb2..82e351d5b 100644 --- a/mmf/datasets/base_dataset.py +++ b/mmf/datasets/base_dataset.py @@ -65,7 +65,8 @@ def init_processors(self): def prepare_batch(self, batch): """ - Can be possibly overridden in your child class + Can be overridden in your child class. Not supported with the Lightning + trainer. Prepare batch for passing to model. Whatever returned from here will be directly passed to model's forward function. Currently moves the batch to diff --git a/mmf/datasets/lightning_datamodule.py b/mmf/datasets/lightning_datamodule.py new file mode 100644 index 000000000..909607452 --- /dev/null +++ b/mmf/datasets/lightning_datamodule.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import pytorch_lightning as pl +from mmf.datasets.multi_dataset_loader import MultiDatasetLoader +from mmf.utils.general import get_batch_size + + +class LightningDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + self.batch_size = get_batch_size() + + self.train_loader = MultiDatasetLoader("train") + self.val_loader = MultiDatasetLoader("val") + self.test_loader = MultiDatasetLoader("test") + + self.train_loader.load(self.config) + self.val_loader.load(self.config) + self.test_loader.load(self.config) + + def train_dataloader(self): + return self.train_loader + + def val_dataloader(self): + return self.val_loader + + def test_dataloader(self): + return self.test_loader diff --git a/mmf/models/base_model.py b/mmf/models/base_model.py index 8912039df..dea75a4b2 100644 --- a/mmf/models/base_model.py +++ b/mmf/models/base_model.py @@ -1,8 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. """ -Models built on top of Pythia need to inherit ``BaseModel`` class and adhere to -some format. To create a model for MMF, follow this quick cheatsheet. +Models built in MMF need to inherit the ``BaseModel`` class and adhere to +a fixed format. To create a model for MMF, follow this quick cheatsheet. 1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your class's ``__init__`` function.
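For orientation: the `trainer.params` block above is passed, nearly verbatim, as keyword arguments to `pytorch_lightning.Trainer`, the new `LightningDataModule` wraps MMF's existing `MultiDatasetLoader`s, and the `BaseModel` changes in the next hunks add the PL hooks. The registered `lightning` trainer added later in this patch does the real wiring (including a `LightningLoopCallback`); the following is only an illustrative sketch of how these pieces fit together, not code from the patch. The helper name `fit_with_lightning` is invented here, and `config`/`model` are assumed to be an already-built MMF config and model.

import pytorch_lightning as pl
from omegaconf import OmegaConf

from mmf.datasets.lightning_datamodule import LightningDataModule


def fit_with_lightning(config, model):
    # Illustrative only: `config` is a built MMF DictConfig, `model` a built
    # MMF model (BaseModel subclass). The real entry point is LightningTrainer.
    data_module = LightningDataModule(config)

    # Switch the model onto the PL code path (training_step / configure_optimizers
    # added to BaseModel below) instead of MMF's own training loop.
    model.is_pl_enabled = True

    # `trainer.params` from defaults.yaml maps onto pytorch_lightning.Trainer
    # keyword arguments. The real trainer also converts max_epochs into an
    # equivalent max_steps via get_max_updates; this sketch simply drops
    # max_epochs and keeps the configured max_steps.
    params = OmegaConf.to_container(config.trainer.params, resolve=True)
    params.pop("max_epochs")
    trainer = pl.Trainer(logger=False, **params)
    trainer.fit(model, data_module)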
@@ -47,21 +47,21 @@ def forward(self, sample_list): from dataclasses import dataclass from typing import List, Optional, Union +import pytorch_lightning as pl from mmf.common.registry import registry -from mmf.common.sample import to_device +from mmf.common.sample import SampleList, to_device from mmf.modules.losses import LossConfig, Losses from mmf.utils.checkpoint import load_pretrained_model from mmf.utils.download import download_pretrained_model from mmf.utils.file_io import PathManager from mmf.utils.general import get_current_device from omegaconf import MISSING, DictConfig, OmegaConf -from torch import nn logger = logging.getLogger(__name__) -class BaseModel(nn.Module): +class BaseModel(pl.LightningModule): """For integration with MMF's trainer, datasets and other features, models needs to inherit this class, call `super`, write a build function, write a forward function taking a ``SampleList`` as input and returning a @@ -84,8 +84,10 @@ def __init__(self, config: Union[DictConfig, Config]): config = OmegaConf.structured(config) self.config = config + self._logged_warning = {"losses_present": False} self._is_pretrained = False + self._is_pl_enabled = False @classmethod def from_params(cls, **kwargs): @@ -95,10 +97,18 @@ def from_params(cls, **kwargs): def is_pretrained(self): return self._is_pretrained + @property + def is_pl_enabled(self): + return self._is_pl_enabled + @is_pretrained.setter def is_pretrained(self, x: bool): self._is_pretrained = x + @is_pl_enabled.setter + def is_pl_enabled(self, x: bool): + self._is_pl_enabled = x + def build(self): """Function to be implemented by the child class, in case they need to build their model separately than ``__init__``. All model related @@ -165,9 +175,63 @@ def forward(self, sample_list, *args, **kwargs): "Forward of the child model class needs to be implemented." ) + def training_step(self, batch, batch_idx, *args, **kwargs): + """Member function of PL modules. Used only when PL is enabled. + Can be overridden by the child class. Takes in a ``SampleList`` and + batch_idx, and returns a dict. + + Args: + batch (SampleList): SampleList returned by the DataLoader for the + current iteration + + Returns: + Dict: Dict containing the loss. + """ + batch = self._ensure_sample_list(batch) + output = self(batch) + loss_dict = output["losses"] + output["loss"] = sum(loss.mean() for loss in loss_dict.values()) + return output + + def validation_step(self, batch, batch_idx, *args, **kwargs): + """Member function of PL modules. Used only when PL is enabled. + Can be overridden by the child class. Takes in a ``SampleList`` and + batch_idx, and returns a dict. + + Args: + batch (SampleList): SampleList returned by the DataLoader for the + current iteration + + Returns: + Dict + """ + batch = self._ensure_sample_list(batch) + output = self(batch) + # TODO: @sash Implementation coming soon! (next PR) + return output + + def configure_optimizers(self): + """Member function of PL modules. Used only when PL is enabled.""" + assert self._is_pl_enabled, ( + "configure_optimizers should only be used as a member " + "function of LightningModule when pytorch lightning is enabled." + ) + + from mmf.utils.build import build_lightning_optimizers + + config = registry.get("config") + return build_lightning_optimizers(self, config) + + def _ensure_sample_list(self, batch): + if not isinstance(batch, SampleList): + # Try converting to SampleList + batch = SampleList(batch) + return batch + def __call__(self, sample_list, *args, **kwargs): - # Move to proper device i.e.
same as the model before passing - sample_list = to_device(sample_list, get_current_device()) + if not self._is_pl_enabled: + # Move to proper device i.e. same as the model before passing + sample_list = to_device(sample_list, get_current_device()) model_output = super().__call__(sample_list, *args, **kwargs) @@ -175,7 +239,7 @@ def __call__(self, sample_list, *args, **kwargs): if self.is_pretrained: return model_output - # Make sure theat the output from the model is a Mapping + # Make sure that the output from the model is a Mapping assert isinstance( model_output, collections.abc.Mapping ), "A dict must be returned from the forward of the model." diff --git a/mmf/modules/losses.py b/mmf/modules/losses.py index 079202724..dc8195457 100644 --- a/mmf/modules/losses.py +++ b/mmf/modules/losses.py @@ -70,7 +70,7 @@ class Losses(nn.Module): mostly doesn't need to use this class. Attributes: - losses: List containing instanttions of each loss + losses: List containing instantiations of each loss passed in config """ diff --git a/mmf/trainers/callbacks/lr_scheduler.py b/mmf/trainers/callbacks/lr_scheduler.py index 81411d47b..fc09347f8 100644 --- a/mmf/trainers/callbacks/lr_scheduler.py +++ b/mmf/trainers/callbacks/lr_scheduler.py @@ -19,7 +19,7 @@ def __init__(self, config, trainer): self._scheduler = None if self.training_config.lr_scheduler is True: - self._scheduler = build_scheduler(self.trainer.optimizer, self.config) + self._scheduler = build_scheduler(trainer.optimizer, self.config) def on_update_end(self, **kwargs): if self._scheduler is not None: diff --git a/mmf/trainers/core/training_loop.py b/mmf/trainers/core/training_loop.py index f14652d4d..e261193f7 100644 --- a/mmf/trainers/core/training_loop.py +++ b/mmf/trainers/core/training_loop.py @@ -2,8 +2,6 @@ import gc import logging -import math -import warnings from abc import ABC from typing import Any, Dict @@ -11,7 +9,7 @@ from mmf.common.registry import registry from mmf.common.report import Report from mmf.common.sample import to_device -from mmf.utils.general import clip_gradients +from mmf.utils.general import clip_gradients, get_max_updates from torch import Tensor @@ -218,32 +216,13 @@ def _extract_loss(self, report: Dict[str, Any]) -> Tensor: return loss def _calculate_max_updates(self): - max_updates = self.training_config.max_updates - max_epochs = self.training_config.max_epochs - if max_updates is None and max_epochs is None: - raise ValueError("Neither max_updates nor max_epochs is specified.") - - if isinstance( - self.train_loader.current_dataset, torch.utils.data.IterableDataset - ): - warnings.warn( - "max_epochs not supported for Iterable datasets. Falling back " - + "to max_updates." - ) - return max_updates - - if max_updates is not None and max_epochs is not None: - warnings.warn( - "Both max_updates and max_epochs are specified. 
" - + f"Favoring max_epochs: {max_epochs}" - ) - - if max_epochs is not None: - max_updates = ( - math.ceil( - len(self.train_loader) / self.training_config.update_frequency - ) - * max_epochs - ) + config_max_updates = self.training_config.max_updates + config_max_epochs = self.training_config.max_epochs + max_updates, _ = get_max_updates( + config_max_updates, + config_max_epochs, + self.train_loader, + self.training_config.update_frequency, + ) return max_updates diff --git a/mmf/trainers/lightning_core/__init__.py b/mmf/trainers/lightning_core/__init__.py new file mode 100644 index 000000000..9020c2df2 --- /dev/null +++ b/mmf/trainers/lightning_core/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/mmf/trainers/lightning_core/loop_callback.py b/mmf/trainers/lightning_core/loop_callback.py new file mode 100644 index 000000000..6624339da --- /dev/null +++ b/mmf/trainers/lightning_core/loop_callback.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +from typing import Any, List + +from mmf.common.registry import registry +from mmf.common.sample import SampleList +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks.base import Callback + + +logger = logging.getLogger(__name__) + + +class LightningLoopCallback(Callback): + def __init__(self, lightning_trainer: Any): + super().__init__() + self.lightning_trainer = lightning_trainer + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + registry.register("current_epoch", trainer.current_epoch) + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: List, + batch: SampleList, + batch_idx: int, + dataloader_idx: int, + ): + # prepare the next batch + self.lightning_trainer.data_module.train_loader.change_dataloader() + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule): + # TODO: @sash next PR + pass + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: List, + batch: SampleList, + batch_idx: int, + dataloader_idx: int, + ): + # prepare the next batch + self.lightning_trainer.data_module.val_loader.change_dataloader() diff --git a/mmf/trainers/lightning_trainer.py b/mmf/trainers/lightning_trainer.py new file mode 100644 index 000000000..7716c3572 --- /dev/null +++ b/mmf/trainers/lightning_trainer.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import math + +import omegaconf +from mmf.common import typings as mmf_typings +from mmf.common.registry import registry +from mmf.datasets.lightning_datamodule import LightningDataModule +from mmf.modules.metrics import Metrics +from mmf.trainers.base_trainer import BaseTrainer +from mmf.trainers.lightning_core.loop_callback import LightningLoopCallback +from mmf.utils.build import build_model +from mmf.utils.general import get_max_updates, print_model_parameters +from omegaconf import OmegaConf +from pytorch_lightning import Trainer, seed_everything + + +logger = logging.getLogger(__name__) + + +@registry.register_trainer("lightning") +class LightningTrainer(BaseTrainer): + def __init__(self, config: mmf_typings.DictConfig): + super().__init__(config) + self.trainer = None + + def load(self): + super().load() + self.trainer_config = self.config.trainer.params + self._calculate_max_updates() + self._load_trainer() + + def _load_trainer(self): + lightning_params = self.trainer_config + + with omegaconf.open_dict(lightning_params): + lightning_params.pop("max_steps") + lightning_params.pop("max_epochs") + + lightning_params_dict = OmegaConf.to_container(lightning_params, resolve=True) + self.trainer = Trainer( + logger=False, + callbacks=self._callbacks, + max_steps=self._max_updates, + **lightning_params_dict + ) + + def configure_device(self) -> None: + pass + + def configure_seed(self) -> None: + seed = self.config.training.seed + seed_everything(seed) + + def load_datasets(self) -> None: + logger.info("Loading datasets") + data_module = LightningDataModule(self.config) + self.data_module = data_module + + def load_model(self) -> None: + logger.info("Loading models") + + attributes = self.config.model_config[self.config.model] + if isinstance(attributes, str): + attributes = self.config.model_config[attributes] + with omegaconf.open_dict(attributes): + attributes.model = self.config.model + + self.model = build_model(attributes) + self.model.is_pl_enabled = True + + def load_optimizer(self) -> None: + logger.info("Loading optimizer: noop for lightning") + + def load_metrics(self) -> None: + logger.info("Loading metrics") + metrics = self.config.evaluation.get("metrics", []) + self.metrics = Metrics(metrics) + self.metrics_params = self.metrics.required_params + + def configure_callbacks(self) -> None: + self._callbacks = [LightningLoopCallback(self)] + + def train(self) -> None: + logger.info("===== Model =====") + logger.info(self.model) + print_model_parameters(self.model) + + logger.info("Starting training...") + self.trainer.fit(self.model, self.data_module) + + def inference(self) -> None: + logger.info("Starting inference...") + # TODO: @sash coming soon + pass + + def _calculate_max_updates(self) -> None: + self._max_updates = self.trainer_config.max_steps + self._max_epochs = self.trainer_config.max_epochs + if self._max_updates is None and self._max_epochs is None: + raise ValueError("Neither max_updates nor max_epochs is specified.") + + train_loader = self.data_module.train_loader + self._max_updates, max_epochs = get_max_updates( + self._max_updates, + self._max_epochs, + train_loader, + self.trainer_config.accumulate_grad_batches, + ) + self._max_epochs = math.ceil(max_epochs) + return self._max_updates diff --git a/mmf/utils/build.py b/mmf/utils/build.py index 2ccb693ad..dec505327 100644 --- a/mmf/utils/build.py +++ b/mmf/utils/build.py @@ -260,6 +260,19 @@ def build_optimizer(model, config): return optimizer +def build_lightning_optimizers(model, config): 
+ optimizer = build_optimizer(model, config) + + if config.training.lr_scheduler: + lr_scheduler = build_scheduler(optimizer, config) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}, + } + else: + return optimizer + + def build_scheduler(optimizer, config): scheduler_config = config.get("scheduler", {}) diff --git a/mmf/utils/general.py b/mmf/utils/general.py index 597245cbe..a093fc286 100644 --- a/mmf/utils/general.py +++ b/mmf/utils/general.py @@ -3,7 +3,9 @@ import collections import gc import logging +import math import os +import warnings from bisect import bisect import torch @@ -289,6 +291,33 @@ def get_sizes_list(dim, chunks): return sizes_list +def get_max_updates(config_max_updates, config_max_epochs, train_loader, update_freq): + if config_max_updates is None and config_max_epochs is None: + raise ValueError("Neither max_updates nor max_epochs is specified.") + + if isinstance(train_loader.current_dataset, torch.utils.data.IterableDataset): + warnings.warn( + "max_epochs not supported for Iterable datasets. Falling back " + + "to max_updates." + ) + return config_max_updates, config_max_epochs + + if config_max_updates is not None and config_max_epochs is not None: + warnings.warn( + "Both max_updates and max_epochs are specified. " + + f"Favoring max_epochs: {config_max_epochs}" + ) + + if config_max_epochs is not None: + max_updates = math.ceil(len(train_loader) / update_freq) * config_max_epochs + max_epochs = config_max_epochs + else: + max_updates = config_max_updates + max_epochs = max_updates / len(train_loader) + + return max_updates, max_epochs + + def get_chunks(x, sizes): out = [] begin = 0 diff --git a/pyproject.toml b/pyproject.toml index 98239b795..7fc0aa67e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ known_third_party = [ "PIL", "cv2", "demjson", "fairscale", "h5py", "lib", "lmdb", "maskrcnn_benchmark", "mmf", "numpy", "omegaconf", "packaging", "pycocoevalcap", "pytorch_sphinx_theme", "recommonmark", "requests", "setuptools", "sklearn", "termcolor", "tests", "torch", - "torchtext", "torchvision", "tqdm", "transformers" + "torchtext", "torchvision", "tqdm", "transformers", "pytorch_lightning" ] skip_glob = "**/build/**,website/**" combine_as_imports = true diff --git a/requirements.txt b/requirements.txt index b2a4b9208..1eebeffe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ omegaconf==2.0.6 lmdb==0.98 termcolor==1.1.0 iopath==0.1.3 +pytorch_lightning==1.1.6 diff --git a/tests/test_utils.py b/tests/test_utils.py index 261eb8b71..1d2b442da 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ import tempfile import unittest +import pytorch_lightning as pl import torch from mmf.common.sample import Sample, SampleList from mmf.utils.general import get_current_device @@ -79,8 +80,6 @@ def compare_state_dicts(a, b): def build_random_sample_list(): - from mmf.common.sample import Sample, SampleList - first = Sample() first.x = random.randint(0, 100) first.y = torch.rand((5, 4)) @@ -121,10 +120,34 @@ def __init__(self, size): self.linear = torch.nn.Linear(size, 1) def forward(self, prepared_batch): + input_sample = SampleList(prepared_batch) batch = prepared_batch[DATA_ITEM_KEY] output = self.linear(batch) - loss = torch.nn.MSELoss()(output, batch) - return {"losses": {"loss": loss}, "logits": output} + loss = torch.nn.MSELoss()(-1 * output, batch) + return {"losses": {"loss": loss}, "logits": output, "input_batch": input_sample} + + +class 
SimpleLightningModel(pl.LightningModule): + def __init__(self, size, config=None): + super().__init__() + self.model = SimpleModel(size) + self.config = config + + def forward(self, prepared_batch): + return self.model(prepared_batch) + + def training_step(self, batch, batch_idx, *args, **kwargs): + output = self(batch) + output["loss"] = output["losses"]["loss"] + return output + + def configure_optimizers(self): + if self.config is None: + return torch.optim.Adam(self.parameters(), lr=0.01) + else: + from mmf.utils.build import build_lightning_optimizers + + return build_lightning_optimizers(self, self.config) def assertModulesEqual(mod1, mod2): diff --git a/tests/trainers/lightning/__init__.py b/tests/trainers/lightning/__init__.py new file mode 100644 index 000000000..9020c2df2 --- /dev/null +++ b/tests/trainers/lightning/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/tests/trainers/lightning/lightning_trainer_mock.py b/tests/trainers/lightning/lightning_trainer_mock.py new file mode 100644 index 000000000..ad12a5c74 --- /dev/null +++ b/tests/trainers/lightning/lightning_trainer_mock.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + + +from unittest.mock import MagicMock + +import torch +from mmf.trainers.lightning_trainer import LightningTrainer +from tests.test_utils import NumbersDataset + + +class LightningTrainerMock(LightningTrainer): + def __init__( + self, + config, + max_steps, + max_epochs=None, + callback=None, + num_data_size=100, + batch_size=1, + accumulate_grad_batches=1, + lr_scheduler=False, + gradient_clip_val=0.0, + precision=32, + ): + self.config = config + self._callbacks = None + if callback: + self._callbacks = [callback] + + # data + self.data_module = MagicMock() + dataset = NumbersDataset(num_data_size) + self.data_module.train_loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=1, + drop_last=False, + ) + self.data_module.train_loader.current_dataset = MagicMock(return_value=dataset) + + # settings + trainer_config = self.config.trainer.params + trainer_config.accumulate_grad_batches = accumulate_grad_batches + trainer_config.precision = precision + trainer_config.max_steps = max_steps + trainer_config.max_epochs = max_epochs + trainer_config.gradient_clip_val = gradient_clip_val + trainer_config.precision = precision + + self.trainer_config = trainer_config + self.config.training.lr_scheduler = lr_scheduler diff --git a/tests/trainers/lightning/test_grad_accumulate.py b/tests/trainers/lightning/test_grad_accumulate.py new file mode 100644 index 000000000..089721f1b --- /dev/null +++ b/tests/trainers/lightning/test_grad_accumulate.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import unittest + +import torch +from tests.trainers.lightning.test_utils import get_lightning_trainer + + +class TestLightningTrainerGradAccumulate(unittest.TestCase): + def test_grad_accumulate(self): + trainer1 = get_lightning_trainer( + accumulate_grad_batches=2, max_steps=2, batch_size=3 + ) + trainer1.trainer.fit(trainer1.model, trainer1.data_module.train_loader) + + trainer2 = get_lightning_trainer( + accumulate_grad_batches=1, max_steps=2, batch_size=6 + ) + trainer2.trainer.fit(trainer2.model, trainer2.data_module.train_loader) + + for param1, param2 in zip( + trainer1.model.parameters(), trainer2.model.parameters() + ): + self.assertTrue(torch.allclose(param1, param2)) diff --git a/tests/trainers/lightning/test_grad_clipping.py b/tests/trainers/lightning/test_grad_clipping.py new file mode 100644 index 000000000..160438a65 --- /dev/null +++ b/tests/trainers/lightning/test_grad_clipping.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest + +import torch +from mmf.utils.general import clip_gradients +from pytorch_lightning.callbacks.base import Callback +from tests.trainers.lightning.test_utils import get_lightning_trainer, get_mmf_trainer + + +class TestLightningTrainerGradClipping(unittest.TestCase, Callback): + def setUp(self): + self.mmf_grads = [] + self.lightning_grads = [] + + self.grad_clip_magnitude = 0.15 + self.grad_clipping_config = { + "max_grad_l2_norm": self.grad_clip_magnitude, + "clip_norm_mode": "all", + } + + def test_grad_clipping_and_parity_to_mmf(self): + mmf_trainer = get_mmf_trainer( + max_updates=5, + max_epochs=None, + grad_clipping_config=self.grad_clipping_config, + ) + + def _finish_update(): + clip_gradients( + mmf_trainer.model, mmf_trainer.num_updates, None, mmf_trainer.config + ) + for param in mmf_trainer.model.parameters(): + mmf_grad = torch.clone(param.grad).detach().item() + self.mmf_grads.append(mmf_grad) + + mmf_trainer.scaler.step(mmf_trainer.optimizer) + mmf_trainer.scaler.update() + mmf_trainer.num_updates += 1 + + mmf_trainer._finish_update = _finish_update + mmf_trainer.training_loop() + + trainer = get_lightning_trainer( + max_steps=5, + max_epochs=None, + gradient_clip_val=self.grad_clip_magnitude, + callback=self, + ) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + + def on_after_backward(self, trainer, pl_module): + for param in pl_module.parameters(): + self.assertLessEqual(param.grad, self.grad_clip_magnitude) + + for lightning_param in pl_module.parameters(): + lightning_grad = torch.clone(lightning_param.grad).detach().item() + self.lightning_grads.append(lightning_grad) + + def on_train_end(self, trainer, pl_module): + for lightning_grad, mmf_grad in zip(self.lightning_grads, self.mmf_grads): + self.assertEqual(lightning_grad, mmf_grad) diff --git a/tests/trainers/lightning/test_loop_conditions.py b/tests/trainers/lightning/test_loop_conditions.py new file mode 100644 index 000000000..099529f3c --- /dev/null +++ b/tests/trainers/lightning/test_loop_conditions.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import unittest + +from tests.trainers.lightning.test_utils import get_lightning_trainer + + +class TestLightningTrainer(unittest.TestCase): + def test_epoch_over_updates(self): + trainer = get_lightning_trainer(max_steps=2, max_epochs=0.04) + self.assertEqual(trainer._max_updates, 4) + + self._check_values(trainer, 0, 0) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + self._check_values(trainer, 4, 0) + + def test_fractional_epoch(self): + trainer = get_lightning_trainer(max_steps=None, max_epochs=0.04) + self.assertEqual(trainer._max_updates, 4) + + self._check_values(trainer, 0, 0) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + self._check_values(trainer, 4, 0) + + def test_updates(self): + trainer = get_lightning_trainer(max_steps=2, max_epochs=None) + self.assertEqual(trainer._max_updates, 2) + + self._check_values(trainer, 0, 0) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + self._check_values(trainer, 2, 0) + + def _check_values(self, trainer, current_iteration, current_epoch): + self.assertEqual(trainer.trainer.global_step, current_iteration) + self.assertEqual(trainer.trainer.current_epoch, current_epoch) diff --git a/tests/trainers/lightning/test_loss.py b/tests/trainers/lightning/test_loss.py new file mode 100644 index 000000000..35b06794e --- /dev/null +++ b/tests/trainers/lightning/test_loss.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest + +from mmf.common.report import Report +from pytorch_lightning.callbacks.base import Callback +from tests.trainers.lightning.test_utils import get_lightning_trainer, get_mmf_trainer + + +class TestLightningTrainerLoss(unittest.TestCase, Callback): + def setUp(self): + self.lightning_losses = [] + self.mmf_losses = [] + + def test_loss_computation_parity_with_mmf_trainer(self): + # compute mmf_trainer training losses + def _on_update_end(report, meter, should_log): + self.mmf_losses.append(report["losses"]["loss"].item()) + + mmf_trainer = get_mmf_trainer( + max_updates=5, max_epochs=None, on_update_end_fn=_on_update_end + ) + mmf_trainer.training_loop() + + # compute lightning_trainer training losses + trainer = get_lightning_trainer(callback=self, max_steps=5) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + + def on_train_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + output = outputs[0][0]["extra"] + report = Report(output["input_batch"], output) + self.lightning_losses.append(report["losses"]["loss"].item()) + + def on_train_end(self, trainer, pl_module): + for lightning_loss, mmf_loss in zip(self.lightning_losses, self.mmf_losses): + self.assertEqual(lightning_loss, mmf_loss) diff --git a/tests/trainers/lightning/test_lr_schedule.py b/tests/trainers/lightning/test_lr_schedule.py new file mode 100644 index 000000000..cee20fd0e --- /dev/null +++ b/tests/trainers/lightning/test_lr_schedule.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import unittest + +import torch +from tests.trainers.lightning.test_utils import ( + get_lightning_trainer, + get_mmf_trainer, + get_trainer_config, +) + + +class TestLightningTrainerLRSchedule(unittest.TestCase): + def test_lr_schedule(self): + # note, be aware some of the logic also is in the SimpleLightningModel + trainer1 = get_lightning_trainer(max_steps=8, lr_scheduler=True) + trainer1.trainer.fit(trainer1.model, trainer1.data_module.train_loader) + + trainer2 = get_lightning_trainer(max_steps=8) + trainer2.trainer.fit(trainer2.model, trainer2.data_module.train_loader) + + last_model_param1 = list(trainer1.model.parameters())[-1] + last_model_param2 = list(trainer2.model.parameters())[-1] + self.assertFalse(torch.allclose(last_model_param1, last_model_param2)) + + def test_lr_schedule_compared_to_mmf_is_same(self): + trainer_config = get_trainer_config() + mmf_trainer = get_mmf_trainer( + max_updates=8, max_epochs=None, scheduler_config=trainer_config.scheduler + ) + mmf_trainer.training_loop() + + trainer = get_lightning_trainer(max_steps=8, lr_scheduler=True) + trainer.trainer.fit(trainer.model, trainer.data_module.train_loader) + + mmf_trainer.model.to(trainer.model.device) + last_model_param1 = list(mmf_trainer.model.parameters())[-1] + last_model_param2 = list(trainer.model.parameters())[-1] + self.assertTrue(torch.allclose(last_model_param1, last_model_param2)) diff --git a/tests/trainers/lightning/test_utils.py b/tests/trainers/lightning/test_utils.py new file mode 100644 index 000000000..11b3f503e --- /dev/null +++ b/tests/trainers/lightning/test_utils.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import torch +from mmf.utils.build import build_optimizer +from omegaconf import OmegaConf +from tests.test_utils import SimpleLightningModel, SimpleModel +from tests.trainers.lightning.lightning_trainer_mock import LightningTrainerMock +from tests.trainers.test_training_loop import TrainerTrainingLoopMock + + +def get_trainer_config(): + return OmegaConf.create( + { + "distributed": {}, + "run_type": "train", + "training": { + "trainer": "lightning", + "detect_anomaly": False, + "evaluation_interval": 4, + "log_interval": 2, + "update_frequency": 1, + "fp16": False, + "lr_scheduler": False, + }, + "optimizer": {"type": "adam_w", "params": {"lr": 5e-5, "eps": 1e-8}}, + "scheduler": { + "type": "warmup_linear", + "params": {"num_warmup_steps": 8, "num_training_steps": 8}, + }, + "trainer": { + "type": "lightning", + "params": { + "gpus": 0 if torch.cuda.is_available() else None, + "num_nodes": 1, + "precision": 32, + "deterministic": True, + "benchmark": False, + "gradient_clip_val": 0.0, + "val_check_interval": 4, + "log_every_n_steps": 2, + "checkpoint_callback": False, + }, + }, + } + ) + + +def get_lightning_trainer( + max_steps, + max_epochs=None, + batch_size=1, + model_size=1, + accumulate_grad_batches=1, + callback=None, + lr_scheduler=False, + gradient_clip_val=0.0, + precision=32, +): + torch.random.manual_seed(2) + trainer = LightningTrainerMock( + config=get_trainer_config(), + max_steps=max_steps, + max_epochs=max_epochs, + callback=callback, + batch_size=batch_size, + accumulate_grad_batches=accumulate_grad_batches, + lr_scheduler=lr_scheduler, + gradient_clip_val=gradient_clip_val, + precision=precision, + ) + trainer.model = SimpleLightningModel(model_size, config=trainer.config) + trainer.model.train() + prepare_lightning_trainer(trainer) + return trainer + + +def get_mmf_trainer( + model_size=1, + num_data_size=100, + max_updates=5, 
+ max_epochs=None, + on_update_end_fn=None, + fp16=False, + scheduler_config=None, + grad_clipping_config=None, +): + torch.random.manual_seed(2) + model = SimpleModel(model_size) + model.train() + trainer_config = get_trainer_config() + optimizer = build_optimizer(model, trainer_config) + trainer = TrainerTrainingLoopMock( + num_data_size, + max_updates, + max_epochs, + config=trainer_config, + optimizer=optimizer, + on_update_end_fn=on_update_end_fn, + fp16=fp16, + scheduler_config=scheduler_config, + grad_clipping_config=grad_clipping_config, + ) + model.to(trainer.device) + trainer.model = model + return trainer + + +def prepare_lightning_trainer(trainer): + trainer.configure_device() + trainer._calculate_max_updates() + trainer._load_trainer() diff --git a/tests/trainers/test_fp16.py b/tests/trainers/test_fp16.py index 31d6233a5..550506540 100644 --- a/tests/trainers/test_fp16.py +++ b/tests/trainers/test_fp16.py @@ -3,7 +3,6 @@ import unittest import torch -from mmf.trainers.mmf_trainer import MMFTrainer from tests.test_utils import SimpleModel, skip_if_no_cuda from tests.trainers.test_training_loop import TrainerTrainingLoopMock @@ -26,7 +25,7 @@ def forward(self, sample_list): return model_output -class MMFTrainerMock(TrainerTrainingLoopMock, MMFTrainer): +class MMFTrainerMock(TrainerTrainingLoopMock): def __init__( self, num_train_data, max_updates, max_epochs, device="cuda", fp16_model=False ): diff --git a/tests/trainers/test_training_loop.py b/tests/trainers/test_training_loop.py index 0ec7dae05..345fa0c33 100644 --- a/tests/trainers/test_training_loop.py +++ b/tests/trainers/test_training_loop.py @@ -6,13 +6,13 @@ import torch from mmf.common.meter import Meter from mmf.common.sample import SampleList -from mmf.trainers.core.profiling import TrainerProfilingMixin -from mmf.trainers.core.training_loop import TrainerTrainingLoopMixin +from mmf.trainers.callbacks.lr_scheduler import LRSchedulerCallback +from mmf.trainers.mmf_trainer import MMFTrainer from omegaconf import OmegaConf from tests.test_utils import NumbersDataset, SimpleModel -class TrainerTrainingLoopMock(TrainerTrainingLoopMixin, TrainerProfilingMixin): +class TrainerTrainingLoopMock(MMFTrainer): def __init__( self, num_train_data, @@ -24,6 +24,8 @@ def __init__( batch_size=1, fp16=False, on_update_end_fn=None, + scheduler_config=None, + grad_clipping_config=None, ): if config is None: self.training_config = OmegaConf.create( @@ -37,6 +39,7 @@ def __init__( ) else: self.training_config = config.training + self.config = config if max_updates is not None: self.training_config["max_updates"] = max_updates @@ -60,6 +63,24 @@ def __init__( else: self.optimizer = optimizer + if scheduler_config: + config.training.lr_scheduler = True + config.scheduler = scheduler_config + self.lr_scheduler_callback = LRSchedulerCallback(config, self) + self.callbacks.append(self.lr_scheduler_callback) + on_update_end_fn = ( + on_update_end_fn + if on_update_end_fn + else self.lr_scheduler_callback.on_update_end + ) + + if grad_clipping_config: + self.training_config.clip_gradients = True + self.training_config.max_grad_l2_norm = grad_clipping_config[ + "max_grad_l2_norm" + ] + self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"] + dataset = NumbersDataset(num_train_data) self.train_loader = torch.utils.data.DataLoader( dataset=dataset, @@ -113,9 +134,9 @@ def test_update_frequency_reporting(self): def _on_update_end(report, meter, should_log): # the losses here should be the sum of two losses in # iteration 0 and 
iteration 1 (both constitute update 0). - # Here iter 1 loss: 0.2684, iter 2 loss: 2.4167 + # Here iter 1 loss: 0.2599, iter 2 loss: 4.2090 loss = report.losses["loss"].detach().cpu().item() - self.assertAlmostEqual(loss, 2.6852, 4) + self.assertAlmostEqual(loss, 4.4688, 4) self._train_with_condition( num_train_data=100,