
Trainer failing silently for multi-node processing #8993

Closed
MartaTintore opened this issue Aug 19, 2021 · 3 comments · Fixed by #18292
Labels: bug, environment: slurm, help wanted, priority: 1 (medium)

Comments

MartaTintore commented Aug 19, 2021

🐛 Bug

When I run Trainer.fit with multiple nodes, the program fails silently and hangs forever.
With 1 node and multiple GPUs the process runs fine, but as soon as I specify 2 or more nodes it just hangs indefinitely.

To Reproduce

import argparse
import datetime
import os
import time

import pytorch_lightning as pl
import submitit
import torch
from omegaconf import OmegaConf
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

# get_parser and instantiate_from_config are helpers from my own project (not shown here).


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None,
                 wrap=False, num_workers=None, distributed=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size*2
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = self._val_dataloader
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = self._test_dataloader
        self.wrap = wrap

    def _train_dataloader(self):
        return DataLoader(self.datasets["train"], 
                          batch_size=self.batch_size,
                        #   num_workers=self.num_workers, 
                          shuffle=True
                          )

    def _val_dataloader(self):
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                        #   num_workers=self.num_workers
                          )

    def _test_dataloader(self):
        return DataLoader(self.datasets["test"], 
                          batch_size=self.batch_size,
                        #   num_workers=self.num_workers
                          )
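
    # Note: self.datasets is populated in a setup() hook that I have omitted here.
    # A minimal sketch of what that hook does (assuming instantiate_from_config
    # turns each stored config into a Dataset):
    def setup(self, stage=None):
        self.datasets = {
            name: instantiate_from_config(cfg)
            for name, cfg in self.dataset_configs.items()
        }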

def train(opt, unknown, now, logdir, nowname):

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        print('Config from:', opt.base)
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        opt.ngpus = opt.ngpus if torch.cuda.is_available() else 0
        trainer_config["gpus"] = opt.ngpus if torch.cuda.is_available() else 0
        trainer_config["num_nodes"] = opt.nodes
        if not torch.cuda.is_available():
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config
        print(lightning_config)

        # trainer and callbacks
        trainer_kwargs = dict()
        
        # add callback which sets up log directory
        default_callbacks_cfg = {
            "checkpointing": {
                "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                "params": {
                    "dirpath": ckptdir,
                    "filename": "{epoch}",
                    "verbose": True,
                }
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                }
            },
        }
        
        # Save best models
        if config.model.monitor is not None:
            print(f"Monitoring {config.model.monitor} as checkpoint metric.")
            default_callbacks_cfg["checkpointing"]["params"]["monitor"] = config.model.monitor
            default_callbacks_cfg["checkpointing"]["params"]["mode"] = "min"
            default_callbacks_cfg["checkpointing"]["params"]["save_top_k"] = config.model.save_top_k
            default_callbacks_cfg["checkpointing"]["params"]["every_n_val_epochs"] = 1
            
        callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
        trainer_kwargs["default_root_dir"] = logdir
        trainer_kwargs["weights_save_path"] = logdir
        trainer_kwargs["max_epochs"] = config.model.max_epochs

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        print('gpus', trainer.gpus, 'nodes', trainer.num_nodes, \
                'processes', trainer.num_processes, 'devices', trainer.devices)
        
        data = instantiate_from_config(config.data)
        data.prepare_data()
        data.setup()

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        total_gpu = opt.ngpus * opt.nodes if opt.ngpus > 0 else 1
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches

        # Initialize model
        print('Creating model...')
        model = instantiate_from_config(config.model)

        model.learning_rate = accumulate_grad_batches * total_gpu * bs * base_lr
        print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (total_num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
            model.learning_rate, accumulate_grad_batches, total_gpu, bs, base_lr))

        # run
        if opt.train:
            data_train = data.train_dataloader()
            data_val = data.val_dataloader()
            start_time = time.time()
            trainer.fit(model, data_train, data_val)
            training_time = datetime.timedelta(seconds=int(time.time() - start_time))
            print("Training took", training_time)
    except Exception:
        raise

def main():

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)
    opt, unknown = parser.parse_known_args()
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    if opt.name:
        name = "_" + opt.name
    elif opt.base:
        cfg_fname = os.path.split(opt.base[0])[-1]
        cfg_name = os.path.splitext(cfg_fname)[0]
        name = "_" + cfg_name
    else:
        name = ""
    nowname = now + name + opt.postfix
    logdir = os.path.join("logs", nowname)
    os.makedirs(logdir, exist_ok=True)

    # executor is the submission interface (logs are dumped in the folder)
    executor = submitit.AutoExecutor(folder=logdir)
    num_gpus_per_node = opt.ngpus
    nodes = opt.nodes
    
    executor.update_parameters(
        mem_gb=80 * num_gpus_per_node,
        timeout_min=1500,
        slurm_partition={NAME},
        gpus_per_node=num_gpus_per_node, 
        tasks_per_node=num_gpus_per_node,
        cpus_per_task=10, 
        nodes=nodes,
        slurm_constraint="volta32gb",
    )

    # Run on cluster
    job = executor.submit(train, opt, unknown, now, logdir, nowname)
   
if __name__ == "__main__":
    main()

Expected behavior

Training the model on multiple nodes.

Environment

  • CUDA:
    - GPU:
    - Quadro GP100
    - Quadro GP100
    - available: True
    - version: 10.2

  • Packages:
    - numpy: 1.19.2
    - pyTorch_debug: False
    - pyTorch_version: 1.9.0
    - pytorch-lightning: 1.4.2
    - tqdm: 4.61.0

  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.5
    - version: #57-Ubuntu SMP Thu Oct 15 10:57:00 UTC 2020

MartaTintore added the bug and help wanted labels on Aug 19, 2021
awaelchli (Contributor) commented Aug 19, 2021

Hey @MartaTintore
There are a lot of unknowns in your posted code. Let's take a step back and make sure the basics work. Can we try this?

import os
import submitit
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment



class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


NUM_GPUS_PER_NODE = 8
NUM_NODES = 2


def train():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=NUM_GPUS_PER_NODE,
        num_nodes=NUM_NODES,
        accelerator="ddp",
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    assert trainer.world_size == 16
    assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def main():
    logdir = "debug_log_dir"
    os.makedirs(logdir, exist_ok=True)

    # executor is the submission interface (logs are dumped in the folder)
    executor = submitit.AutoExecutor(folder=logdir)
    executor.update_parameters(
        mem_gb=80 * NUM_GPUS_PER_NODE,
        timeout_min=1500,
        slurm_partition={NAME},
        gpus_per_node=NUM_GPUS_PER_NODE,
        tasks_per_node=NUM_GPUS_PER_NODE,
        cpus_per_task=10,
        nodes=NUM_NODES,
        slurm_constraint="volta32gb",
    )
    job = executor.submit(train)


if __name__ == "__main__":
    main()

You may have to fill in a detail here and there; I have not run this myself.

awaelchli added the priority: 1 label on Aug 19, 2021
MartaTintore (Author) commented Aug 19, 2021

Hi,

Thank you for the reply.

By running the basics I realized the problem was with the flags I was using to launch the job.
args.ngpus is a flag I created myself to set the number of GPUs per node that I request from the cluster.
However, PyTorch Lightning also needs the number of GPUs, passed separately as --gpus on the same command line.

Conclusion: I need to specify both (i) --gpus and (ii) --ngpus, so that (i) PL knows which GPUs to use and (ii) I request the same number of GPUs from the cluster (see the short sketch below).
Maybe it would be nice to have a warning/message when the --gpus flag is not specified? It didn't complain about GPUs not being specified for PL...
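
To illustrate, this is roughly how the two flags end up wired (a minimal sketch; the parser and executor details are simplified and the flag names follow my snippet above):

import argparse
import submitit
from pytorch_lightning import Trainer

parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)              # adds --gpus / --num_nodes for Lightning
parser.add_argument("--ngpus", type=int, default=8)     # GPUs per node requested from the cluster
parser.add_argument("--nodes", type=int, default=2)     # nodes requested from the cluster
opt, unknown = parser.parse_known_args()

executor = submitit.AutoExecutor(folder="logs")
executor.update_parameters(
    gpus_per_node=opt.ngpus,   # cluster-side request
    tasks_per_node=opt.ngpus,  # one task per GPU
    nodes=opt.nodes,
)
# Lightning reads --gpus / --num_nodes from its own arguments; the SLURM request
# alone does not tell the Trainer how many devices or nodes to use.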

Thanks again for your help!

awaelchli (Contributor) commented Aug 19, 2021

@MartaTintore happy to hear that you were able to find the problem.

Yes, that's good feedback. Lightning could do a better job of validating that the Lightning parameters match the world size set by the SLURM environment.

However, to this point:

Maybe it would be nice to have a warning/message when the --gpus flag is not specified? It didn't complain about GPUs not being specified for PL...

Unfortunately I believe we cannot know this! It's perfectly valid to run one or multiple training processes purely on CPU, but Lightning cannot know that unless the user provides this information somehow.
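
In the meantime, a sanity check at the start of the submitted function can catch the mismatch early. A rough sketch (SLURM_NNODES and SLURM_NTASKS are set by the scheduler for the allocation):

import os


def check_slurm_matches_trainer(gpus_per_node: int, num_nodes: int) -> None:
    # Compare what will be passed to the Trainer against what SLURM actually allocated.
    slurm_nodes = int(os.environ.get("SLURM_NNODES", "1"))
    slurm_tasks = int(os.environ.get("SLURM_NTASKS", "1"))
    expected_tasks = gpus_per_node * num_nodes
    if slurm_nodes != num_nodes or slurm_tasks != expected_tasks:
        raise RuntimeError(
            f"Trainer was configured for {expected_tasks} processes on {num_nodes} nodes, "
            f"but SLURM allocated {slurm_tasks} tasks on {slurm_nodes} nodes."
        )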
