Fix backward compatibility for older versions of lightning #182

Merged 14 commits on Apr 20, 2021
122 changes: 119 additions & 3 deletions integrations/test_metric_lightning.py
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from torch import tensor
from torch.utils.data import DataLoader

from integrations.lightning_models import BoringModel
from torchmetrics import Metric
from integrations.lightning_models import BoringModel, RandomDataset
from torchmetrics import Accuracy, AveragePrecision, Metric


class SumMetric(Metric):
@@ -80,6 +83,119 @@ def training_epoch_end(self, outs):
trainer.fit(model)


def test_metrics_reset(tmpdir):
"""Tests that metrics are reset correctly after the end of the train/val/test epoch."""

class TestModel(LightningModule):

def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 1)

for stage in ['train', 'val', 'test']:
acc = Accuracy()
acc.reset = mock.Mock(side_effect=acc.reset)
ap = AveragePrecision(num_classes=1, pos_label=1)
ap.reset = mock.Mock(side_effect=ap.reset)
self.add_module(f"acc_{stage}", acc)
self.add_module(f"ap_{stage}", ap)

def forward(self, x):
return self.layer(x)

def _step(self, stage, batch):
labels = (batch.detach().sum(1) > 0).float() # Fake some targets
logits = self.forward(batch)
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1))
probs = torch.sigmoid(logits.detach())
self.log(f"loss/{stage}", loss)

acc = self._modules[f"acc_{stage}"]
ap = self._modules[f"ap_{stage}"]

labels_int = labels.to(torch.long)
acc(probs.flatten(), labels_int)
ap(probs.flatten(), labels_int)

# Metric.forward calls reset internally, so clear the mock call history here
acc.reset.reset_mock()
ap.reset.reset_mock()

self.log(f"{stage}/accuracy", acc)
self.log(f"{stage}/ap", ap)

return loss

def training_step(self, batch, batch_idx, *args, **kwargs):
return self._step('train', batch)

def validation_step(self, batch, batch_idx, *args, **kwargs):
return self._step('val', batch)

def test_step(self, batch, batch_idx, *args, **kwargs):
return self._step('test', batch)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def test_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def _assert_epoch_end(self, stage):
acc = self._modules[f"acc_{stage}"]
ap = self._modules[f"ap_{stage}"]

acc.reset.assert_not_called()
ap.reset.assert_not_called()

def on_train_epoch_end(self, outputs):
self._assert_epoch_end('train')

def on_validation_epoch_end(self, outputs):
self._assert_epoch_end('val')

def on_test_epoch_end(self, outputs):
self._assert_epoch_end('test')

def _assert_called(model, stage):
acc = model._modules[f"acc_{stage}"]
ap = model._modules[f"ap_{stage}"]

acc.reset.assert_called_once()
acc.reset.reset_mock()

ap.reset.assert_called_once()
ap.reset.reset_mock()

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
max_epochs=1,
progress_bar_refresh_rate=0,
)

trainer.fit(model)
_assert_called(model, 'train')
_assert_called(model, 'val')

trainer.validate(model)
_assert_called(model, 'val')

trainer.test(model)
_assert_called(model, 'test')


# todo: reconsider if it makes sense to keep here
# def test_metric_lightning_log(tmpdir):
# """ Test logging a metric object and that the metric state gets reset after each epoch."""
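The reset assertions above depend on wrapping each metric's reset method in a mock that still runs the original. A minimal sketch of that pattern, assuming only unittest.mock and the public Accuracy metric (the per-stage bookkeeping from the test is left out):

from unittest import mock

from torchmetrics import Accuracy

acc = Accuracy()
# Wrapping the bound method keeps the real behaviour (side_effect) while the
# Mock records every call, so the test can assert exactly when reset() ran.
acc.reset = mock.Mock(side_effect=acc.reset)

acc.reset()                     # the real reset still executes
acc.reset.assert_called_once()  # and the call was recorded
acc.reset.reset_mock()          # clear the call history before the next stage
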
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|buil
known_first_party = [
"torchmetrics",
"tests",
"integrations",
]
skip_glob = []
profile = "black"
7 changes: 5 additions & 2 deletions tests/bases/test_metric.py
@@ -22,7 +22,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
from torchmetrics.utilities.imports import _LIGHTNING_GREATER_THAN_1_3, _TORCH_LOWER_1_6

seed_all(42)

@@ -101,7 +101,10 @@ def test_reset_compute():
a.update(tensor(5))
assert a.compute() == 5
a.reset()
assert a.compute() == 0
if _LIGHTNING_GREATER_THAN_1_3:
assert a.compute() == 0
else:
assert a.compute() == 5


def test_update():
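The version-guarded assertion above comes down to whether reset() also drops the cached compute() result. A rough sketch of the observable behaviour with a minimal custom metric (illustrative only, mirroring the DummyMetricSum used by the test):

from torch import tensor

from torchmetrics import Metric


class SumMetric(Metric):

    def __init__(self):
        super().__init__()
        self.add_state("x", tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


m = SumMetric()
m.update(tensor(5))
assert m.compute() == 5  # the result is cached internally
m.reset()
# With pytorch_lightning > 1.3 installed, reset() clears the cache and
# compute() re-evaluates the default state, giving 0; with older lightning
# the cached 5 survives so self.log can still read the value after the
# epoch-end reset.
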
28 changes: 9 additions & 19 deletions tests/classification/test_accuracy.py
@@ -154,7 +154,7 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
(_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True),
(_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True),
(_av_preds_ml, _av_target_ml, 5 / 8, None, False),
(_av_preds_ml, _av_target_ml, 0, None, True)
(_av_preds_ml, _av_target_ml, 0, None, True),
],
)
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
@@ -209,18 +209,10 @@ def test_topk_accuracy_wrong_input_types(preds, target):
("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5),
("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5),
(None, None, None, _input_mcls_prob, None, 0, 0.5),
(None, None, None, _input_mcls_prob, None, None, 1.5)
(None, None, None, _input_mcls_prob, None, None, 1.5),
],
)
def test_wrong_params(
average,
mdmc_average,
num_classes,
inputs,
ignore_index,
top_k,
threshold
):
def test_wrong_params(average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold):
preds, target = inputs.preds, inputs.target

with pytest.raises(ValueError):
@@ -250,14 +242,12 @@ def test_wrong_params(

@pytest.mark.parametrize(
"preds_mc, target_mc, preds_ml, target_ml",
[
(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
)
],
[(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
)],
)
def test_different_modes(preds_mc, target_mc, preds_ml, target_ml):
acc = Accuracy()
5 changes: 1 addition & 4 deletions torchmetrics/classification/accuracy.py
@@ -221,7 +221,6 @@ def update(self, preds: Tensor, target: Tensor):
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""

""" returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)

@@ -234,9 +233,7 @@ def update(self, preds: Tensor, target: Tensor):
self.subset_accuracy = False

if self.subset_accuracy:
correct, total = _subset_accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k,
)
correct, total = _subset_accuracy_update(preds, target, threshold=self.threshold, top_k=self.top_k)
self.correct += correct
self.total += total
else:
9 changes: 6 additions & 3 deletions torchmetrics/functional/classification/accuracy.py
@@ -31,7 +31,7 @@ def _mode(
threshold: float,
top_k: Optional[int],
num_classes: Optional[int],
multiclass: Optional[bool]
multiclass: Optional[bool],
) -> DataType:
mode = _check_classification_inputs(
preds, target, threshold=threshold, top_k=top_k, num_classes=num_classes, multiclass=multiclass
@@ -49,7 +49,7 @@ def _accuracy_update(
top_k: Optional[int],
multiclass: Optional[bool],
ignore_index: Optional[int],
mode: DataType
mode: DataType,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
@@ -89,7 +89,10 @@ def _accuracy_compute(


def _subset_accuracy_update(
preds: Tensor, target: Tensor, threshold: float, top_k: Optional[int],
preds: Tensor,
target: Tensor,
threshold: float,
top_k: Optional[int],
) -> Tuple[Tensor, Tensor]:

preds, target = _input_squeeze(preds, target)
6 changes: 5 additions & 1 deletion torchmetrics/metric.py
@@ -24,6 +24,7 @@
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.imports import _LIGHTNING_GREATER_THAN_1_3


class Metric(nn.Module, ABC):
@@ -256,7 +257,10 @@ def reset(self):
"""
This method automatically resets the metric state variables to their default value.
"""
self._computed = None
# lower lightning versions implicitly require `_computed` to be kept here
# so that metric objects are logged correctly via self.log
if _LIGHTNING_GREATER_THAN_1_3:
self._computed = None

for attr, default in self._defaults.items():
current_val = getattr(self, attr)
15 changes: 11 additions & 4 deletions torchmetrics/utilities/imports.py
@@ -15,9 +15,10 @@
import operator
from importlib import import_module
from importlib.util import find_spec
from typing import Optional

from packaging.version import Version
from pkg_resources import DistributionNotFound
from pkg_resources import DistributionNotFound, get_distribution


def _module_available(module_path: str) -> bool:
@@ -39,7 +40,7 @@ def _module_available(module_path: str) -> bool:
return False


def _compare_version(package: str, op, version) -> bool:
def _compare_version(package: str, op, version) -> Optional[bool]:
"""
Compare package version with some requirements

@@ -49,10 +50,15 @@ def _compare_version(package: str, op, version) -> bool:
"""
try:
pkg = import_module(package)
pkg_version = pkg.__version__
except (ModuleNotFoundError, DistributionNotFound):
return False
return None
except ImportError:
# catches cyclic imports - the case with integrated libs
# see: https://stackoverflow.com/a/32965521
pkg_version = get_distribution(package).version
try:
pkg_version = Version(pkg.__version__)
pkg_version = Version(pkg_version)
except TypeError:
# this is mocked by sphinx, so it shall return True to generate all summaries
return True
@@ -64,3 +70,4 @@
_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_LIGHTNING_GREATER_THAN_1_3 = _compare_version("pytorch_lightning", operator.gt, "1.3.0")
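
A small usage sketch of how these flags are meant to be consumed now that _compare_version can return None: both None (package not installed) and False (version too low) are falsy, so a plain if keeps the guarded path disabled. Assuming only the standard operator module:

import operator

from torchmetrics.utilities.imports import _compare_version

_LIGHTNING_GREATER_THAN_1_3 = _compare_version("pytorch_lightning", operator.gt, "1.3.0")

if _LIGHTNING_GREATER_THAN_1_3:
    ...  # lightning > 1.3 is importable: reset() may clear the compute cache
else:
    ...  # lightning missing or <= 1.3: keep the backward-compatible behaviour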