diff --git a/pl_bolts/callbacks/ssl_online.py b/pl_bolts/callbacks/ssl_online.py index 31998d7875..5e2f01ad57 100644 --- a/pl_bolts/callbacks/ssl_online.py +++ b/pl_bolts/callbacks/ssl_online.py @@ -2,10 +2,10 @@ import torch from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.metrics.functional import accuracy from torch import device, Tensor from torch.nn import functional as F from torch.optim import Optimizer +from torchmetrics.functional import accuracy class SSLOnlineEvaluator(Callback): # pragma: no cover diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index b262709f1c..21178f5e17 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -2,12 +2,12 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.metrics.functional import accuracy from torch import nn from torch.nn import functional as F from torch.nn.functional import softmax from torch.optim import Adam from torch.optim.optimizer import Optimizer +from torchmetrics.functional import accuracy class LogisticRegression(pl.LightningModule): diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py index cd8db83f61..fbac4e70f3 100644 --- a/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -2,8 +2,8 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.metrics import Accuracy from torch.nn import functional as F +from torchmetrics import Accuracy from pl_bolts.models.self_supervised import SSLEvaluator diff --git a/requirements.txt b/requirements.txt index dbb435e71b..55acd2a55e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ torch>=1.6 +torchmetrics>=0.2.0 pytorch-lightning>=1.1.1 \ No newline at end of file