TensorMetric not updated to cuda device #2274

Closed
@tayden

Description

🐛 Bug

When tensors located on the GPU are passed through a TensorMetric's forward method, the method transfers them back to the CPU. It seems that the metric's device is not updated when training on a GPU; the 'cpu' device appears to be hardcoded here.
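
Conceptually, the behaviour looks like the following sketch (illustrative only, not the actual Lightning source): the metric stores a device that defaults to the CPU and converts its output to that device after forward runs, so GPU inputs come back as CPU tensors.

import torch

class TensorMetricSketch:
    """Illustrative stand-in for the suspected behaviour, not the real class."""
    def __init__(self, name: str):
        self.name = name
        self._device = torch.device('cpu')  # apparently never updated on GPU runs

    def __call__(self, *args: torch.Tensor) -> torch.Tensor:
        out = self.forward(*args)
        # the output is converted onto self._device, i.e. moved back to the CPU
        return out.to(self._device)

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError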

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn.functional as F
from pytorch_lightning.metrics import TensorMetric


class FocalTverskyMetric(TensorMetric):
    def __init__(self, alpha: float, beta: float, gamma: float, smooth=1e-8):
        super().__init__("FocalTversky")
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

        # This line works as a workaround
        # self._device = torch.device('cuda')

    def _tversky_index_c(self, p: torch.Tensor, g: torch.Tensor):
        c = p.shape[1]
        p = p.permute(0, 2, 3, 1).reshape((-1, c))
        g = F.one_hot(g.flatten().long(), c)

        tp = torch.sum(torch.mul(p, g), dim=0)
        fn = torch.sum(torch.mul(1. - p, g), dim=0)
        fp = torch.sum(torch.mul(p, 1. - g), dim=0)
        return (tp + self.smooth) / (tp + self.alpha * fn + self.beta * fp + self.smooth)

    def forward(self, x, y):
        ti = self._tversky_index_c(x, y)
        res = (1 - ti).pow(1 / self.gamma)
        return torch.sum(res, dim=0)


if __name__ == '__main__':
    metric = FocalTverskyMetric(alpha=0.5, beta=0.5, gamma=1.)

    preds = torch.Tensor([[[[1.]], [[0.]]], [[[1.]], [[0.]]], [[[0.]], [[1.]]]]).cuda()
    assert preds.is_cuda  # Passes
    labels = torch.Tensor([[[1]], [[1]], [[0]]]).cuda()
    assert labels.is_cuda  # Passes

    loss = metric(preds, labels)
    assert loss.is_cuda  # Fails
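
As a hedged workaround, using the script above, the device can also be overridden after construction; this relies on the private _device attribute noted in the class definition, so it may break between releases.

metric = FocalTverskyMetric(alpha=0.5, beta=0.5, gamma=1.)
# workaround: point the metric's private device at the GPU before calling it
metric._device = torch.device('cuda')

loss = metric(preds, labels)
assert loss.is_cuda  # passes with the workaround applied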

When training in DDP mode with 16-bit precision, this metric produces the stack trace below. The error disappears in 32-bit mode, since it is AMP's grad_scaler asserting that the loss value is a CUDA tensor.

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1,2,3]
Using 16bit precision.
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/4
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/4
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/4
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/4
----------------------------------------------------------------------------------------------------
distributed_backend=ddp
All DDP processes registered. Starting ddp with 4 processes
----------------------------------------------------------------------------------------------------

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | DeepLabV3          | 60 M
1 | calc_loss | FocalTverskyMetric | 0
2 | calc_iou  | IoU                | 0
Epoch 1:   0% 0/493 [00:00<?, ?it/s] Traceback (most recent call last):
  File "/opt/code/deeplabv3_lightning.py", line 269, in <module>
    'pred': predict
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/code/deeplabv3_lightning.py", line 224, in train
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 895, in fit
    self.ddp_train(task, model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 526, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 374, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 457, in run_training_epoch
    _outputs = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 633, in run_training_batch
    loss, batch_output = optimizer_closure()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 611, in optimizer_closure
    model_ref.backward(self, closure_loss, optimizer, opt_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 179, in backward
    self.trainer.scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 156, in scale
    assert outputs.is_cuda
AssertionError
Traceback (most recent call last):
  File "deeplabv3_lightning.py", line 269, in <module>
    'pred': predict
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "deeplabv3_lightning.py", line 224, in train
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 910, in fit
    self.spawn_ddp_children(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 442, in spawn_ddp_children
    self.ddp_train(local_rank, model, is_master=True)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 526, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 374, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 457, in run_training_epoch
    _outputs = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 633, in run_training_batch
    loss, batch_output = optimizer_closure()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 611, in optimizer_closure
    model_ref.backward(self, closure_loss, optimizer, opt_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 179, in backward
    self.trainer.scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 156, in scale
    assert outputs.is_cuda
AssertionError
(The remaining DDP worker processes raise an identical AssertionError traceback.)
Exception ignored in: <function tqdm.__del__ at 0x7ff26f244a70>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1086, in __del__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1293, in close
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1471, in display
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1089, in __repr__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1433, in format_dict
TypeError: cannot unpack non-iterable NoneType object
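
For reference, the same assertion can be triggered outside Lightning with a minimal example (assuming a CUDA-capable machine and the torch 1.6 nightly listed below): GradScaler.scale refuses losses that live on the CPU, which is exactly what the traces above show.

import torch

scaler = torch.cuda.amp.GradScaler()
cpu_loss = torch.tensor(1.0, requires_grad=True)  # a loss that ended up on the CPU
scaler.scale(cpu_loss)  # AssertionError: outputs.is_cuda (same as in the traces)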

Expected behavior

The TensorMetric should have self._device updated to equal the current Trainer device during initialization.
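
One possible shape for this (an illustrative sketch only, not the existing Lightning API) would be to make the metric device-aware in the same way nn.Module is, so the Trainer can move it together with the model instead of relying on a hardcoded torch.device('cpu').

import torch

class DeviceAwareMetric(torch.nn.Module):
    """Hypothetical base class: tracks its device via a registered buffer."""
    def __init__(self, name: str):
        super().__init__()
        self.name = name
        # .to()/.cuda() move this buffer, so its device always mirrors the module's
        self.register_buffer("_dummy", torch.zeros(1))

    @property
    def device(self) -> torch.device:
        return self._dummy.device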

Environment

  • CUDA:
    • GPU:
      • GeForce GTX 1050
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0.dev20200618
    • pytorch-lightning: 0.8.0
    • tensorboard: 2.1.1
    • tqdm: 4.46.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.3
    • version: #41-Ubuntu SMP Wed Jun 3 18:57:02 UTC 2020

The bug is also present on AWS p3.8xlarge instances using the same environment, but with four NVIDIA Tesla V100 GPUs.

Labels

bug, help wanted, priority: 0
