
Running Ray with PyTorch Lightning in a SLURM job causes failure with error "ValueError: signal only works in main thread" #3651

Closed
@rashindrie

Description

🐛 Bug

I followed the instructions at https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html to integrate Ray Tune with PyTorch Lightning. However, when I submit a SLURM job to run the tuning, I get the following error:
ValueError: signal only works in main thread

I reported the same issue to the Ray project at ray/issues/10995 and was offered a hack to work around it.

Could we look for a way to disable the SLURM detection in PyTorch Lightning itself, so that external parties do not have to hack their way around it?
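
For reference, the workaround amounts to something like the sketch below (a minimal sketch, not the exact code from the Ray issue): make the Trainer believe it is running interactively rather than under a SLURM-managed task, since register_slurm_signal_handlers() in PL 0.9.0 skips signal registration when SLURM_JOB_NAME is "bash".

import os

def train_run(config_params, num_epochs=10, num_gpus=1):
    # Workaround sketch: PL 0.9.0 treats SLURM_JOB_NAME == "bash" as an
    # interactive session and skips signal.signal() registration, which would
    # otherwise fail inside Ray Tune's non-main worker thread.
    os.environ["SLURM_JOB_NAME"] = "bash"

    model = ExperimentAE(params=config_params)
    trainer = Trainer(max_epochs=num_epochs, gpus=num_gpus)
    trainer.fit(model)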

To Reproduce

Steps to reproduce the behavior:

  1. Set up Ray Tune by following the instructions in the tutorial linked above.
  2. Submit a SLURM job that runs the tuning script.
  3. See the error:
ray.tune.error.TuneError: Trial raised an exception. Traceback:
ray::ImplicitFunc.train() (pid=4432, ip=172.26.92.190)
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 227, in run
    self._entrypoint()
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 290, in entrypoint
    self._status_reporter.get_checkpoint())
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 497, in _trainable_func
    output = train_func(config)
  File "tune.py", line 261, in train_run
    trainer.fit(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1184, in run_pretrain_routine
    self.register_slurm_signal_handlers()
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 240, in register_slurm_signal_handlers
    signal.signal(signal.SIGUSR1, self.sig_handler)
  File "/usr/local/easybuild-2019/easybuild/software/mpi/gcc/8.3.0/openmpi/3.1.4/python/3.7.4/lib/python3.7/signal.py", line 47, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread
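
For context, Ray Tune's function API runs the training function on a background thread of the trial process (see function_runner.py in the traceback above), while CPython only allows signal handlers to be installed from the main thread. The snippet below is a minimal sketch of that underlying restriction, independent of Lightning or Ray:

import signal
import threading

def install_handler():
    # signal.signal() raises ValueError when called off the main thread
    try:
        signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    except ValueError as err:
        print(err)  # "signal only works in main thread"

t = threading.Thread(target=install_handler)
t.start()
t.join()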

Code sample

slurm script

#!/bin/bash

#SBATCH --gres=gpu:4
#SBATCH --nodes=1
#SBATCH --ntasks=1


# load necessary modules #
module purge
module load scikit-learn/0.21.3-python-3.7.4
module load python/3.7.4

python -u tune.py &> "tune_output.txt"

tune.py

from functools import partial
from os.path import join

import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCallback

from segnet import SegNet  # assumed import path for the user-defined SegNet model


class CustomDataSet(Dataset):
    def __init__(self, csv_file, img_dir, transform):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_loc = join(self.img_dir, self.data.name[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        label = self.data.label[idx]

        return tensor_image, label


class ExperimentAE(pl.LightningModule):
    def __init__(self,
                 params: dict,
                 **kwargs) -> None:
        super(ExperimentAE, self).__init__()

        self.params = params
        self.model = SegNet()

    def forward(self, x):
        # delegate to the underlying SegNet; returns (reconstruction, latent)
        return self.model(x)

    def _run_step(self, x):
        x_hat, z = self.model(x)
        return x_hat, z

    def generate(self, x):
        return self._run_step(x)[0]

    def step(self, batch, batch_idx):
        x, y = batch
        self.curr_device = x.device
        x_hat, z = self._run_step(x)
        loss = F.mse_loss(x_hat, x, reduction='mean')
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        # step() returns a dict, so pull out the loss tensor before logging
        train_loss = self.step(batch, batch_idx)["loss"]
        logs = {"ptl/train_loss": train_loss}
        return {"loss": train_loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        val_loss = self.step(batch, batch_idx)
        return {"val_loss": val_loss["loss"]}

    def validation_end(self, outputs):
        # validation_step returns "val_loss", so aggregate that key
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}
        
    def configure_optimizers(self):
        optims = []

        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.params['lr'],
                               # weight_decay=self.params['weight_decay']
                               )
        optims.append(optimizer)
        return optims

    def train_dataloader(self):
        transform = self.data_transforms(train=True)
        img_dir = "/data/brca/"
        train_csv = "/original/train.csv"

        dataset = CustomDataSet(train_csv, img_dir, transform=transform)
        loader = DataLoader(dataset, shuffle=True, batch_size=self.params['batch_size'], num_workers=4)

        self.num_train_imgs = dataset.__len__()

        return loader

    def val_dataloader(self):
        transform = self.data_transforms(train=False)

        img_dir = "/data/brca/"
        val_csv = "/original/validation.csv"

        val_dataset = CustomDataSet(val_csv, img_dir, transform=transform)
        self.valid_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=self.params['batch_size'],
                                           num_workers=4)

        self.num_val_imgs = self.valid_dataloader.__len__()

        return self.valid_dataloader

    def test_dataloader(self):
        pass

    def data_transforms(self, train=True):
        if train:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
            ])
            return transform_train

        else:
            transform_val = transforms.Compose([
                transforms.ToTensor(),
            ])
            return transform_val


def train_run(config_params, num_epochs=10, num_gpus=1):
    model = ExperimentAE(params=config_params)

    trainer = Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        logger=TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCallback(
                {
                    "loss": "val_loss",
                },
                on="validation_end")
        ])

    trainer.fit(model)


def tune_run(num_samples=20, num_epochs=10, gpus_per_trial=1):
    tune_config = {
        "lr": tune.loguniform(1e-4, 1e-5, 1e-3),
        "batch_size": tune.choice([2, 4, 8]),
        # 'weight_decay': 0.0,
        'scheduler_gamma': tune.choice([1, 0.95, 0.9, 0.85, 0.6]),
    }

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=10,
        grace_period=1,
        reduction_factor=2)

    reporter = CLIReporter(
        parameter_columns=["lr", "batch_size"],
        metric_columns=["loss", "training_iteration"]
    )
    tune.run(
        partial(
            train_run,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial
        ),
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        config=tune_config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_segnet_v1"
    )


if __name__ == "__main__":
    tune_run(num_samples=20, num_epochs=10, gpus_per_trial=1)

Expected behavior

The Ray Tune program should run properly in a SLURM environment.

Environment

  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.17.3
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0
    • tqdm: 4.46.0
    • ray: 0.8.7
    • tensorflow: 2.1.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.7.4
    • version: #1 SMP Tue May 26 15:05:43 EDT 2020

Additional context

Labels

3rd party, bug, help wanted, waiting on author
