Skip to content

Commit

Permalink
Fixing torchmetrics deprecations
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Feb 1, 2022
1 parent d79cd18 commit 70f492d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection
from torchmetrics import Accuracy, FBetaScore, JaccardIndex, MetricCollection

from ..datasets.utils import unbind_samples
from . import utils
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, **kwargs: Any) -> None:
num_classes=self.hparams["num_classes"], average="macro"
),
"JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]),
"F1Score": FBeta(
"F1Score": FBetaScore(
num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
),
},
Expand Down Expand Up @@ -292,7 +292,7 @@ def __init__(self, **kwargs: Any) -> None:
average="macro",
multiclass=False,
),
"F1Score": FBeta(
"F1Score": FBetaScore(
num_classes=self.hparams["num_classes"],
beta=1.0,
average="micro",
Expand Down

0 comments on commit 70f492d

Please sign in to comment.