test_step hangs after one iteration when on multiple GPUs #3730

Closed
vegovs opened this issue Sep 29, 2020 · 6 comments · Fixed by #3819
Labels: bug (Something isn't working), distributed (Generic distributed-related topic), help wanted (Open to be worked on)

Comments

@vegovs commented Sep 29, 2020

🐛 Bug

When running the same code on a machine with one GPU, test_step runs as normal and logs what it should.
However, on a node with 4 GPUs, it hangs after one iteration!

Code sample

    def test_step(self, batch, batch_idx):
        images, masks = batch["image"], batch["mask"]
        if images.shape[1] != self.hparams.n_channels:
            raise AssertionError(
                f"Network has been defined with {self.n_channels} input channels, "
                f"but loaded images have {images.shape[1]} channels. Please check that "
                "the images are loaded correctly."
            )

        masks = (
            masks.type(torch.float32)
            if self.hparams.n_classes == 1
            else masks.type(torch.long)
        )

        masks_pred = self(images)  # Forward pass
        loss = self.loss_function(masks_pred, masks)
        result = pl.EvalResult(loss, checkpoint_on=loss)
        result.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
        rand_idx = randint(0, self.hparams.batch_size - 1)  # may index past the end of a short final batch
        onehot = torch.sigmoid(masks_pred[rand_idx]) > 0.5
        for tag, value in self.named_parameters():
            tag = tag.replace(".", "/")
            self.logger.experiment.add_histogram(tag, value, self.current_epoch)
        mask_grid = torchvision.utils.make_grid([masks[rand_idx], onehot], nrow=2)
        self.logger.experiment.add_image(
            "TEST - Target vs Predicted", mask_grid, self.current_epoch
        )
        alpha = 0.5
        image_grid = torchvision.utils.make_grid(
            [
                images[rand_idx],
                torch.clamp(
                    kornia.enhance.add_weighted(
                        src1=images[rand_idx],
                        alpha=1.0,
                        src2=onehot,
                        beta=alpha,
                        gamma=0.0,
                    ),
                    max=1.0,
                ),
            ]
        )
        self.logger.experiment.add_image(
            "TEST - Image vs Predicted", image_grid, self.current_epoch
        )
        pred = (torch.sigmoid(masks_pred) > 0.5).float()
        f1 = f1_score(pred, masks, self.hparams.n_classes + 1)
        rec = recall(pred, masks, self.hparams.n_classes + 1)
        pres = precision(pred, masks, self.hparams.n_classes + 1)
        result.log("test_f1", f1, on_epoch=True)
        result.log("test_recall", rec, on_epoch=True)
        result.log("test_precision", pres, on_epoch=True)

        return result

Expected behavior

I expect it to finish the test epoch.

Environment

Environment 1

  • CUDA:
    • GPU:
      • GeForce RTX 2070 SUPER
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.19.2
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0
    • tqdm: 4.49.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.6.9
    • version: #52~18.04.1-Ubuntu SMP Thu Sep 10 12:50:22 UTC 2020

Environment 2

  • CUDA:
    • GPU:
      • GeForce RTX 2080 Ti
      • GeForce RTX 2080 Ti
      • GeForce RTX 2080 Ti
      • GeForce RTX 2080 Ti
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.19.1
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0
    • tqdm: 4.49.0
  • System:
@awaelchli (Contributor)

Can you post which trainer settings you are using?

@vegovs (Author) commented Sep 30, 2020

    def training_step(self, batch, batch_idx):
        images, masks = batch["image"], batch["mask"]
        if images.shape[1] != self.hparams.n_channels:
            raise AssertionError(
                f"Network has been defined with {self.hparams.n_channels} input channels, "
                f"but loaded images have {images.shape[1]} channels. Please check that "
                "the images are loaded correctly."
            )

        masks = (
            masks.type(torch.float32)
            if self.hparams.n_classes == 1
            else masks.type(torch.long)
        )

        masks_pred = self(images)  # Forward pass
        loss = self.loss_function(masks_pred, masks)
        result = pl.TrainResult(minimize=loss)
        result.log("train_loss", loss, sync_dist=True)
        if batch_idx == 0:
            self.logg_images(images, masks, masks_pred, "TRAIN")
        pred = (torch.sigmoid(masks_pred) > 0.5).float()
        f1 = f1_score(pred, masks, self.hparams.n_classes + 1)
        rec = recall(pred, masks, self.hparams.n_classes + 1)
        pres = precision(pred, masks, self.hparams.n_classes + 1)
        result.log("train_f1", f1, on_epoch=True)
        result.log("train_recall", rec, on_epoch=True)
        result.log("train_precision", pres, on_epoch=True)

        return result

    def validation_step(self, batch, batch_idx):
        images, masks = batch["image"], batch["mask"]
        if images.shape[1] != self.hparams.n_channels:
            raise AssertionError(
                f"Network has been defined with {self.n_channels} input channels, "
                f"but loaded images have {images.shape[1]} channels. Please check that "
                "the images are loaded correctly."
            )

        masks = (
            masks.type(torch.float32)
            if self.hparams.n_classes == 1
            else masks.type(torch.long)
        )

        masks_pred = self(images)  # Forward pass
        loss = self.loss_function(masks_pred, masks)
        result = pl.EvalResult(loss, checkpoint_on=loss)
        result.log("val_loss", loss, sync_dist=True)
        if batch_idx == 0:
            self.logg_images(images, masks, masks_pred, "VAL")
        pred = (torch.sigmoid(masks_pred) > 0.5).float()
        f1 = f1_score(pred, masks, self.hparams.n_classes + 1)
        rec = recall(pred, masks, self.hparams.n_classes + 1)
        pres = precision(pred, masks, self.hparams.n_classes + 1)
        result.log("val_f1", f1, on_epoch=True)
        result.log("val_recall", rec, on_epoch=True)
        result.log("val_precision", pres, on_epoch=True)

        return result

    def test_step(self, batch, batch_idx):
        images, masks = batch["image"], batch["mask"]
        if images.shape[1] != self.hparams.n_channels:
            raise AssertionError(
                f"Network has been defined with {self.n_channels} input channels, "
                f"but loaded images have {images.shape[1]} channels. Please check that "
                "the images are loaded correctly."
            )

        masks = (
            masks.type(torch.float32)
            if self.hparams.n_classes == 1
            else masks.type(torch.long)
        )

        masks_pred = self(images)  # Forward pass
        loss = self.loss_function(masks_pred, masks)
        result = pl.EvalResult(loss, checkpoint_on=loss)
        result.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
        self.logg_images(images, masks, masks_pred, "TEST")
        pred = (torch.sigmoid(masks_pred) > 0.5).float()
        f1 = f1_score(pred, masks, self.hparams.n_classes + 1)
        rec = recall(pred, masks, self.hparams.n_classes + 1)
        pres = precision(pred, masks, self.hparams.n_classes + 1)
        result.log("test_f1", f1, on_epoch=True)
        result.log("test_recall", rec, on_epoch=True)
        result.log("test_precision", pres, on_epoch=True)

        return result

@awaelchli (Contributor)

What arguments do you pass to Trainer(...)?
Do you use distributed_backend="ddp"?

@vegovs (Author) commented Sep 30, 2020

    try:
        trainer = Trainer.from_argparse_args(
            args,
            gpus=-1,
            precision=16,
            distributed_backend="ddp",
            callbacks=[lr_monitor],
            early_stop_callback=early_stopping,
            accumulate_grad_batches=1
            if not os.getenv("ACC_GRAD")
            else int(os.getenv("ACC_GRAD")),
            gradient_clip_val=0.0
            if not os.getenv("GRAD_CLIP")
            else float(os.getenv("GRAD_CLIP")),
            max_epochs=1000 if not os.getenv("EPOCHS") else int(os.getenv("EPOCHS")),
            default_root_dir=os.getcwd()
            if not os.getenv("DIR_ROOT_DIR")
            else os.getenv("DIR_ROOT_DIR"),
        )
        trainer.fit(model)
        trainer.test(model)
    except KeyboardInterrupt:
        torch.save(model.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)

@vegovs (Author) commented Sep 30, 2020

Is using DDP the issue? I used DDP in the one-GPU environment as well.

@awaelchli (Contributor) commented Sep 30, 2020

Yes, unfortunately it looks like it. Calling trainer.fit followed by trainer.test currently does not work with DDP, because DDP launches the training script in separate processes. We haven't found a good solution for this yet. With a single GPU it should work, though. One workaround is to switch to ddp_spawn, but it is not ideal: all your classes then need to be picklable.
Alternatively, you can move trainer.test to a separate script and run it independently.
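
In code, the two workarounds look roughly like this (MyModel and the checkpoint path are illustrative placeholders, not code from this issue):

    from pytorch_lightning import Trainer

    model = MyModel()  # your LightningModule (placeholder name)

    # Workaround 1: switch to ddp_spawn so fit() and test() run in the
    # same process. All of the model's attributes must be picklable.
    trainer = Trainer(gpus=-1, precision=16, distributed_backend="ddp_spawn")
    trainer.fit(model)
    trainer.test(model)

    # Workaround 2: keep ddp for training, then test from a separate
    # script that loads the trained weights:
    #     model = MyModel.load_from_checkpoint("path/to/checkpoint.ckpt")
    #     trainer = Trainer(gpus=1)
    #     trainer.test(model)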
