[fix] multinode jobs after recent lightning update (#921)
Summary:
After Sasha's update of PyTorch Lightning on MMF master, multinode jobs in the MMF codebase broke. The root problem traces back to PR Lightning-AI/pytorch-lightning#6802, which assumes that SLURM_PROCID points to the worker rank. That assumption is wrong for frameworks that launch their own processes later via multiprocessing spawn and run with ntasks_per_node=1: the first node gets procid = 0, the second node gets procid = 1, and so on. Since this value is used to mask prepare_data in LightningDataModule, prepare_data ends up being called on every worker of the first node instead of only on rank zero, causing inconsistencies. In particular, the barrier call inside prepare_data is executed by the first node's workers but not by the others, which causes the job to block later on.
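
For illustration, a minimal sketch of the mismatch (the node/worker counts and variable names here are illustrative assumptions, not part of the PR):

```python
# Illustrative only: 2 nodes, SLURM launched with ntasks_per_node=1,
# and each task later spawns 8 worker processes via multiprocessing spawn.
NODES = 2
WORKERS_PER_NODE = 8

for node in range(NODES):
    slurm_procid = node  # SLURM_PROCID is per SLURM task, i.e. one per node here
    for local_rank in range(WORKERS_PER_NODE):
        global_rank = node * WORKERS_PER_NODE + local_rank
        # Treating SLURM_PROCID as the global rank marks *every* spawned worker
        # on the first node as "rank zero" -- so prepare_data (and its barrier)
        # runs on all of node 0's workers but on none of the others.
        looks_like_rank_zero = slurm_procid == 0
        actually_rank_zero = global_rank == 0
        print(node, local_rank, global_rank, looks_like_rank_zero, actually_rank_zero)
```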

This PR fixes the issue by ensuring on our side that prepare_data is only called on rank zero. Since this can cause further confusion, we also remove the sync barrier calls from download; users are now expected to handle is_master checks on their own.
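
A minimal sketch of the resulting guard pattern (mirroring the change to build_multiple_datamodules in mmf/utils/build.py below; the wrapper function name is hypothetical):

```python
from mmf.utils.distributed import is_master, synchronize

def prepare_and_setup(datamodule_instance, dataset_config):
    # Download/build only on the rank-zero (master) process ...
    if is_master():
        datamodule_instance.prepare_data(dataset_config)
    # ... then bring all processes back in sync before setup(), which runs
    # on every rank and only loads the already-prepared data.
    synchronize()
    datamodule_instance.setup(config=dataset_config)
```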

Pull Request resolved: #921

Test Plan: Tested in multinode settings.

Reviewed By: vedanuj

Differential Revision: D28156855

Pulled By: apsdehal

fbshipit-source-id: 4e0dd5317e15153f558d34c6951a89299602454f
apsdehal authored and facebook-github-bot committed May 3, 2021
1 parent 1cc6b08 commit 8a01258
Showing 4 changed files with 33 additions and 15 deletions.
32 changes: 24 additions & 8 deletions mmf/datasets/base_dataset_builder.py
@@ -40,6 +40,7 @@ def build(self, config, dataset_type, *args, **kwargs):
import pytorch_lightning as pl
from mmf.utils.build import build_dataloader_and_sampler
from mmf.utils.distributed import is_master, synchronize
from omegaconf import DictConfig
from torch.utils.data import Dataset


@@ -72,13 +73,26 @@ def dataset_name(self, dataset_name):
self._dataset_name = dataset_name

def prepare_data(self, config, *args, **kwargs):
"""
NOTE: The caller to this function should only call this on master process
in a distributed settings so that downloads and build only happen
on master process and others can just load it. Make sure to call
synchronize afterwards to bring all processes in sync.
Lightning automatically wraps datamodule in a way that it is only
called on a master node, but for extra precaution as lightning
can introduce bugs, we should always call this under master process
with extra checks on our sides as well.
"""
self.config = config
self.build_dataset(config)

def setup(self, stage: Optional[str] = None):
self.train_dataset = self.load_dataset(self.config, "train")
self.val_dataset = self.load_dataset(self.config, "val")
self.test_dataset = self.load_dataset(self.config, "test")
def setup(self, stage: Optional[str] = None, config: Optional[DictConfig] = None):
if config is None:
config = self.config
self.train_dataset = self.load_dataset(config, "train")
self.val_dataset = self.load_dataset(config, "val")
self.test_dataset = self.load_dataset(config, "test")

@property
def train_dataset(self) -> Optional[Dataset]:
@@ -110,6 +124,11 @@ def build_dataset(self, config, dataset_type="train", *args, **kwargs):
time when it is not available. This internally calls 'build' function.
Override that function in your child class.
NOTE: The caller to this function should only call this on the master process
in a distributed setting so that downloads and build only happen
on the master process and others can just load it. Make sure to call
synchronize afterwards to bring all processes in sync.
Args:
config (DictConfig): Configuration of this dataset loaded from
config.
@@ -119,10 +138,7 @@ def build_dataset(self, config, dataset_type="train", *args, **kwargs):
DO NOT OVERRIDE in child class. Instead override ``build``.
"""
# Only build in main process, so none of the others have to build
if is_master():
self.build(config, dataset_type, *args, **kwargs)
synchronize()
self.build(config, dataset_type, *args, **kwargs)

def load_dataset(self, config, dataset_type="train", *args, **kwargs):
"""Main load function use by MMF. This will internally call ``load``
1 change: 1 addition & 0 deletions mmf/datasets/mmf_dataset_builder.py
@@ -61,6 +61,7 @@ def set_dataset_class(self, dataset_cls):
self.dataset_class = dataset_cls

def build(self, config, dataset_type="train", *args, **kwargs):
self.config = config
requirements = config.get("zoo_requirements", [])

if len(requirements) == 0:
10 changes: 7 additions & 3 deletions mmf/utils/build.py
@@ -87,7 +87,6 @@ def build_model(
model = model_class(config)

if hasattr(model, "build"):
model.load_requirements()
""" Model build involves checkpoint loading
If the checkpoint is not available the underlying
methods try to download it.
@@ -99,6 +98,7 @@
using already downloaded checkpoint.
"""
if is_master():
model.load_requirements()
model.build()
synchronize()
else:
@@ -207,8 +207,12 @@ def build_multiple_datamodules(
+ " in config. Proceeding with empty config."
)
dataset_config = OmegaConf.create()
datamodule_instance.prepare_data(dataset_config)
datamodule_instance.setup()

if is_master():
datamodule_instance.prepare_data(dataset_config)

synchronize()
datamodule_instance.setup(config=dataset_config)
if hasattr(datamodule_instance, "update_registry_for_model"):
datamodule_instance.update_registry_for_model(dataset_config)
datamodules[dataset] = datamodule_instance
5 changes: 1 addition & 4 deletions mmf/utils/download.py
@@ -22,7 +22,6 @@
import numpy as np
import requests
import tqdm
from mmf.utils.distributed import is_master, synchronize
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_absolute_path
from PIL import Image
@@ -376,9 +375,7 @@ def download_pretrained_model(model_name, *args, **kwargs):
version = model_config.version
resources = model_config.resources

if is_master():
download_resources(resources, download_path, version)
synchronize()
download_resources(resources, download_path, version)

return download_path

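With the barrier removed from mmf/utils/download.py, callers that run distributed jobs are now expected to add the is_master/synchronize guard themselves. A minimal sketch, assuming the download is a no-op once the resources are already on disk; the wrapper function name is hypothetical:

```python
from mmf.utils.distributed import is_master, synchronize
from mmf.utils.download import download_pretrained_model

def fetch_pretrained(model_name: str) -> str:
    # Only the master process downloads; the others wait at the barrier and
    # then resolve the same (already populated) download path.
    if is_master():
        download_pretrained_model(model_name)
    synchronize()
    return download_pretrained_model(model_name)
```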
