
false device inconsistent runtime error #602

Closed
cnut1648 opened this issue Nov 3, 2021 · 3 comments · Fixed by #606
Labels
bug / fix Something isn't working help wanted Extra attention is needed

Comments


cnut1648 commented Nov 3, 2021

🐛 Bug

Using AUROC with average='weighted' raises a device-mismatch RuntimeError even though the input tensors and the metric are on the same device.

To Reproduce

import torch
from torchmetrics import AUROC

auc = AUROC(num_classes=7, average='weighted')
x = torch.randn(13, 7).to('cuda')
y = torch.randint(7, (13,)).to('cuda')  # integer labels in [0, 7)
auc = auc.to(x.device)
# runtime error
auc(x, y)

would result in

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Based on this error message, I assume some intermediate tensors are created on the CPU, but I am not sure.
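That suspicion can be illustrated with a minimal CPU-only sketch (nothing here is the actual torchmetrics code; `w` and `w_ok` are hypothetical intermediates):

```python
import torch

x = torch.randn(3)  # stands in for an input that may live on the GPU

# A tensor created without an explicit device lands on the default
# device (CPU), regardless of where other operands live:
w = torch.tensor([1.0, 2.0, 3.0])
assert w.device.type == "cpu"

# On a CUDA machine, combining x.to('cuda') with w would raise the
# "Expected all tensors to be on the same device" RuntimeError.
# Creating the intermediate on the operand's device avoids that:
w_ok = torch.tensor([1.0, 2.0, 3.0], device=x.device)
assert w_ok.device == x.device
```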

Environment

  • PyTorch Version (e.g., 1.0): 1.8.1
  • OS (e.g., Linux): Linux Ubuntu
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version: 10.1
  • GPU models and configuration:
  • Any other relevant information:
    Torchmetrics version 0.6.0
    Torch Lightning 1.4.9
@cnut1648 cnut1648 added bug / fix Something isn't working help wanted Extra attention is needed labels Nov 3, 2021
Member

Borda commented Nov 3, 2021

This looks like a common confusion; please see #531.

@Borda Borda closed this as completed Nov 3, 2021
@Borda Borda changed the title [BUG] false device inconsistent runtime error using weighted AUC false device inconsistent runtime error Nov 3, 2021
Author

cnut1648 commented Nov 3, 2021

Thanks for the reply. I am using a PL LightningModule, and I think this is actually related to AUROC with average set to 'weighted'.
Here is my code

from pytorch_lightning import LightningModule
from torch.nn import ModuleDict
from torchmetrics import Accuracy, AUROC, F1, MetricCollection, Precision, Recall

class FineTuneModule(LightningModule):
    def __init__(
            self,
            arch: str,
            **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # metrics
        mc = MetricCollection({
            "accuracy": Accuracy(threshold=0.0),
            "recall": Recall(threshold=0.0, num_classes=7, average='macro'),
            "precision": Precision(threshold=0.0, num_classes=7, average='macro'),
            "f1": F1(threshold=0.0, num_classes=7, average='macro'),
            "macro_auc": AUROC(num_classes=7, average='macro'),
            "weighted_auc": AUROC(num_classes=7, average='weighted')
        })
        self.metrics: ModuleDict[str, MetricCollection] = ModuleDict({
            f"{phase}_metric": mc.clone()
            for phase in ["train", "valid", "test"]
        })

and in the *_step_end hooks I call

# phase is train or valid or test
metrics = self.metrics[f"{phase}_metric"]
metrics(output['prob'], output['label'])

And I think

  1. all metrics in the collection are on the same device (i.e., for metric_name, metric in metrics.items(): print(f"{metric_name}: {metric.device}") prints cuda for every metric)
  2. all metrics can compute correctly except weighted_auc

Sorry I didn't make this clear.

@SkafteNicki SkafteNicki reopened this Nov 4, 2021
@SkafteNicki
Member

SkafteNicki commented Nov 4, 2021

Reopening, as I can confirm there is a bug in the code here.
It seems to happen when some labels are not present in the target tensor, in which case the code evaluates these lines:
https://github.com/PyTorchLightning/metrics/blob/fef83a028880f2ad3e0c265b3e5bb8a184805798/torchmetrics/functional/classification/auroc.py#L138-L139
That tensor probably needs to be initialized on the correct device. Going to send a PR :]
