diff --git a/mmf/datasets/base_dataset_builder.py b/mmf/datasets/base_dataset_builder.py
index d69b04511..2ae7e7704 100644
--- a/mmf/datasets/base_dataset_builder.py
+++ b/mmf/datasets/base_dataset_builder.py
@@ -40,6 +40,7 @@ def build(self, config, dataset_type, *args, **kwargs):
 
 import pytorch_lightning as pl
 from mmf.utils.build import build_dataloader_and_sampler
 from mmf.utils.distributed import is_master, synchronize
+from omegaconf import DictConfig
 from torch.utils.data import Dataset
 
@@ -72,13 +73,26 @@ def dataset_name(self, dataset_name):
         self._dataset_name = dataset_name
 
     def prepare_data(self, config, *args, **kwargs):
+        """
+        NOTE: The caller should only invoke this on the master process in a
+        distributed setting, so that downloads and builds happen only on the
+        master process and the other processes can simply load the result.
+        Make sure to call synchronize afterwards to bring all processes in sync.
+
+        Lightning automatically wraps the datamodule so that this is only
+        called on the master node, but as an extra precaution (Lightning can
+        introduce bugs), we should always call this from the master process
+        with extra checks on our side as well.
+        """
         self.config = config
         self.build_dataset(config)
 
-    def setup(self, stage: Optional[str] = None):
-        self.train_dataset = self.load_dataset(self.config, "train")
-        self.val_dataset = self.load_dataset(self.config, "val")
-        self.test_dataset = self.load_dataset(self.config, "test")
+    def setup(self, stage: Optional[str] = None, config: Optional[DictConfig] = None):
+        if config is None:
+            config = self.config
+        self.train_dataset = self.load_dataset(config, "train")
+        self.val_dataset = self.load_dataset(config, "val")
+        self.test_dataset = self.load_dataset(config, "test")
 
     @property
     def train_dataset(self) -> Optional[Dataset]:
@@ -110,6 +124,11 @@ def build_dataset(self, config, dataset_type="train", *args, **kwargs):
         time when it is not available. This internally calls 'build' function.
         Override that function in your child class.
 
+        NOTE: The caller should only invoke this on the master process in a
+        distributed setting, so that downloads and builds happen only on the
+        master process and the other processes can simply load the result.
+        Make sure to call synchronize afterwards to bring all processes in sync.
+
         Args:
             config (DictConfig): Configuration of this dataset loaded from
                 config.
@@ -119,10 +138,7 @@ def build_dataset(self, config, dataset_type="train", *args, **kwargs):
 
         DO NOT OVERRIDE in child class. Instead override ``build``.
         """
-        # Only build in main process, so none of the others have to build
-        if is_master():
-            self.build(config, dataset_type, *args, **kwargs)
-        synchronize()
+        self.build(config, dataset_type, *args, **kwargs)
 
     def load_dataset(self, config, dataset_type="train", *args, **kwargs):
         """Main load function use by MMF. This will internally call ``load``
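The docstrings added above shift the master-process gating out of ``build_dataset`` and onto the caller. A minimal sketch of the resulting calling contract (mirroring what ``build_multiple_datamodules`` does later in this patch), assuming a hypothetical ``MyDataModule`` subclass of the builder; ``is_master`` and ``synchronize`` are the existing helpers from mmf.utils.distributed:

    from mmf.utils.distributed import is_master, synchronize
    from omegaconf import OmegaConf

    config = OmegaConf.create({"data_dir": "/tmp/mmf"})  # placeholder dataset config
    datamodule = MyDataModule()  # hypothetical BaseDatasetBuilder subclass

    if is_master():
        # downloads and builds happen once, on the master process only
        datamodule.prepare_data(config)

    synchronize()  # every rank blocks here until the master has finished building
    datamodule.setup(config=config)  # all ranks load the already-built datasets
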
diff --git a/mmf/datasets/mmf_dataset_builder.py b/mmf/datasets/mmf_dataset_builder.py
index 97cdca982..1beb6497d 100644
--- a/mmf/datasets/mmf_dataset_builder.py
+++ b/mmf/datasets/mmf_dataset_builder.py
@@ -61,6 +61,7 @@ def set_dataset_class(self, dataset_cls):
         self.dataset_class = dataset_cls
 
     def build(self, config, dataset_type="train", *args, **kwargs):
+        self.config = config
         requirements = config.get("zoo_requirements", [])
 
         if len(requirements) == 0:
diff --git a/mmf/utils/build.py b/mmf/utils/build.py
index 71d14c2dd..630c149e2 100644
--- a/mmf/utils/build.py
+++ b/mmf/utils/build.py
@@ -87,7 +87,6 @@ def build_model(
     model = model_class(config)
 
     if hasattr(model, "build"):
-        model.load_requirements()
         """
         Model build involves checkpoint loading
         If the checkpoint is not available the underlying methods try to download it.
@@ -99,6 +98,7 @@
         using already downloaded checkpoint.
         """
         if is_master():
+            model.load_requirements()
             model.build()
             synchronize()
         else:
@@ -207,8 +207,12 @@
                 + " in config. Proceeding with empty config."
             )
             dataset_config = OmegaConf.create()
-        datamodule_instance.prepare_data(dataset_config)
-        datamodule_instance.setup()
+
+        if is_master():
+            datamodule_instance.prepare_data(dataset_config)
+
+        synchronize()
+        datamodule_instance.setup(config=dataset_config)
         if hasattr(datamodule_instance, "update_registry_for_model"):
             datamodule_instance.update_registry_for_model(dataset_config)
         datamodules[dataset] = datamodule_instance
diff --git a/mmf/utils/download.py b/mmf/utils/download.py
index fa4519259..52605b2a9 100644
--- a/mmf/utils/download.py
+++ b/mmf/utils/download.py
@@ -22,7 +22,6 @@
 import numpy as np
 import requests
 import tqdm
-from mmf.utils.distributed import is_master, synchronize
 from mmf.utils.file_io import PathManager
 from mmf.utils.general import get_absolute_path
 from PIL import Image
@@ -376,9 +375,7 @@ def download_pretrained_model(model_name, *args, **kwargs):
     version = model_config.version
     resources = model_config.resources
 
-    if is_master():
-        download_resources(resources, download_path, version)
-        synchronize()
+    download_resources(resources, download_path, version)
 
     return download_path
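Note that ``download_pretrained_model`` no longer gates itself on the master process, so call sites become responsible for the gating, mirroring what ``build_model`` now does. A minimal sketch under that assumption; ``some_pretrained_model`` is a placeholder zoo key, and the second call is assumed to be a cheap cache hit because the download utilities skip resources already marked as built:

    from mmf.utils.distributed import is_master, synchronize
    from mmf.utils.download import download_pretrained_model

    model_name = "some_pretrained_model"  # placeholder zoo key

    if is_master():
        # only the master rank actually downloads the resources
        path = download_pretrained_model(model_name)

    synchronize()  # other ranks wait for the master to finish

    if not is_master():
        # resources are already on disk, so this resolves to the cached path
        path = download_pretrained_model(model_name)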