Commit
Fix backward compatibility for older versions of lightning (#182)
* imports

* add integration test

* fix

* format

* fix2

* format

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
SkafteNicki and Borda authored Apr 20, 2021
1 parent bc5f9a9 commit 7115789
Showing 9 changed files with 165 additions and 37 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -57,7 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `_stable_1d_sort` to work when n >= N ([PL^6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177))
- Fixed `_computed` attribute not being correctly reset ([#147](https://github.com/PyTorchLightning/metrics/pull/147))
- Fixed BLEU score ([#165](https://github.com/PyTorchLightning/metrics/pull/165))

- Fixed backwards compatibility for logging with older versions of pytorch lightning ([#182](https://github.com/PyTorchLightning/metrics/pull/182))

## [0.2.0] - 2021-03-12

128 changes: 125 additions & 3 deletions integrations/test_metric_lightning.py
@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from torch import tensor
from torch.utils.data import DataLoader

from integrations.lightning_models import BoringModel
from torchmetrics import Metric
from integrations.lightning_models import BoringModel, RandomDataset
from torchmetrics import Accuracy, AveragePrecision, Metric
from torchmetrics.utilities.imports import _LIGHTNING_GREATER_EQUAL_1_3


class SumMetric(Metric):
@@ -80,6 +85,123 @@ def training_epoch_end(self, outs):
trainer.fit(model)


@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason='test requires lightning v1.3 or higher')
def test_metrics_reset(tmpdir):
"""Tests that metrics are reset correctly after the end of the train/val/test epoch.
Taken from:
https://github.com/PyTorchLightning/pytorch-lightning/pull/7055
"""

class TestModel(LightningModule):

def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 1)

for stage in ['train', 'val', 'test']:
acc = Accuracy()
acc.reset = mock.Mock(side_effect=acc.reset)
ap = AveragePrecision(num_classes=1, pos_label=1)
ap.reset = mock.Mock(side_effect=ap.reset)
self.add_module(f"acc_{stage}", acc)
self.add_module(f"ap_{stage}", ap)

def forward(self, x):
return self.layer(x)

def _step(self, stage, batch):
labels = (batch.detach().sum(1) > 0).float() # Fake some targets
logits = self.forward(batch)
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1))
probs = torch.sigmoid(logits.detach())
self.log(f"loss/{stage}", loss)

acc = self._modules[f"acc_{stage}"]
ap = self._modules[f"ap_{stage}"]

labels_int = labels.to(torch.long)
acc(probs.flatten(), labels_int)
ap(probs.flatten(), labels_int)

# Metric.forward calls reset internally, so clear the mocks here
acc.reset.reset_mock()
ap.reset.reset_mock()

self.log(f"{stage}/accuracy", acc)
self.log(f"{stage}/ap", ap)

return loss

def training_step(self, batch, batch_idx, *args, **kwargs):
return self._step('train', batch)

def validation_step(self, batch, batch_idx, *args, **kwargs):
return self._step('val', batch)

def test_step(self, batch, batch_idx, *args, **kwargs):
return self._step('test', batch)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def test_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def _assert_epoch_end(self, stage):
acc = self._modules[f"acc_{stage}"]
ap = self._modules[f"ap_{stage}"]

acc.reset.assert_not_called()
ap.reset.assert_not_called()

def training_epoch_end(self, outputs):
self._assert_epoch_end('train')

def validation_epoch_end(self, outputs):
self._assert_epoch_end('val')

def test_epoch_end(self, outputs):
self._assert_epoch_end('test')

def _assert_called(model, stage):
acc = model._modules[f"acc_{stage}"]
ap = model._modules[f"ap_{stage}"]

acc.reset.assert_called_once()
acc.reset.reset_mock()

ap.reset.assert_called_once()
ap.reset.reset_mock()

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
max_epochs=1,
progress_bar_refresh_rate=0,
)

trainer.fit(model)
_assert_called(model, 'train')
_assert_called(model, 'val')

trainer.validate(model)
_assert_called(model, 'val')

trainer.test(model)
_assert_called(model, 'test')


# todo: reconsider if it makes sense to keep here
# def test_metric_lightning_log(tmpdir):
# """ Test logging a metric object and that the metric state gets reset after each epoch."""
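The integration test above hinges on wrapping each metric's reset method in a mock that still calls through to the real implementation, so behaviour is unchanged while every call becomes observable. A minimal, self-contained sketch of that pattern (using Accuracy directly; the tensors and assertions below are illustrative, not taken from the test):

from unittest import mock

import torch
from torchmetrics import Accuracy

acc = Accuracy()
# wrap reset: the original method still runs via side_effect, but calls are now recorded
acc.reset = mock.Mock(side_effect=acc.reset)

acc.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
acc.reset.assert_not_called()  # update() alone never resets the state

acc(torch.tensor([1, 0]), torch.tensor([1, 0]))
# Metric.forward (with the default compute_on_step=True) resets around a temporary
# update to produce a batch value, which is why the test clears the mocks right
# after calling the metric
assert acc.reset.call_count >= 1
acc.reset.reset_mock()
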
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|buil
known_first_party = [
"torchmetrics",
"tests",
"integrations",
]
skip_glob = []
profile = "black"
7 changes: 5 additions & 2 deletions tests/bases/test_metric.py
@@ -22,7 +22,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3, _TORCH_LOWER_1_6

seed_all(42)

@@ -101,7 +101,10 @@ def test_reset_compute():
a.update(tensor(5))
assert a.compute() == 5
a.reset()
assert a.compute() == 0
if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
assert a.compute() == 0
else:
assert a.compute() == 5


def test_update():
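For context on the branch above: with lightning >= 1.3 (or without lightning) reset() now also drops the cached compute result, while older lightning keeps it so that self.log can still read the last value at epoch end. A rough, self-contained sketch of the first case (SumMetric below is an illustrative stand-in for the DummyMetricSum test helper):

import torch
from torchmetrics import Metric


class SumMetric(Metric):
    # illustrative stand-in for DummyMetricSum: one additive state with default 0
    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.x += value

    def compute(self):
        return self.x


a = SumMetric()
a.update(torch.tensor(5.0))
assert a.compute() == 5
a.reset()
# the cached result is gone and the state is back at its default
assert a.compute() == 0
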
28 changes: 9 additions & 19 deletions tests/classification/test_accuracy.py
@@ -154,7 +154,7 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
(_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True),
(_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True),
(_av_preds_ml, _av_target_ml, 5 / 8, None, False),
(_av_preds_ml, _av_target_ml, 0, None, True)
(_av_preds_ml, _av_target_ml, 0, None, True),
],
)
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
@@ -209,18 +209,10 @@ def test_topk_accuracy_wrong_input_types(preds, target):
("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5),
("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5),
(None, None, None, _input_mcls_prob, None, 0, 0.5),
(None, None, None, _input_mcls_prob, None, None, 1.5)
(None, None, None, _input_mcls_prob, None, None, 1.5),
],
)
def test_wrong_params(
average,
mdmc_average,
num_classes,
inputs,
ignore_index,
top_k,
threshold
):
def test_wrong_params(average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold):
preds, target = inputs.preds, inputs.target

with pytest.raises(ValueError):
@@ -250,14 +242,12 @@ def test_wrong_params(

@pytest.mark.parametrize(
"preds_mc, target_mc, preds_ml, target_ml",
[
(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
)
],
[(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
)],
)
def test_different_modes(preds_mc, target_mc, preds_ml, target_ml):
acc = Accuracy()
5 changes: 1 addition & 4 deletions torchmetrics/classification/accuracy.py
@@ -221,7 +221,6 @@ def update(self, preds: Tensor, target: Tensor):
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""

""" returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)

@@ -234,9 +233,7 @@ def update(self, preds: Tensor, target: Tensor):
self.subset_accuracy = False

if self.subset_accuracy:
correct, total = _subset_accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k,
)
correct, total = _subset_accuracy_update(preds, target, threshold=self.threshold, top_k=self.top_k)
self.correct += correct
self.total += total
else:
9 changes: 6 additions & 3 deletions torchmetrics/functional/classification/accuracy.py
@@ -31,7 +31,7 @@ def _mode(
threshold: float,
top_k: Optional[int],
num_classes: Optional[int],
multiclass: Optional[bool]
multiclass: Optional[bool],
) -> DataType:
mode = _check_classification_inputs(
preds, target, threshold=threshold, top_k=top_k, num_classes=num_classes, multiclass=multiclass
@@ -49,7 +49,7 @@ def _accuracy_update(
top_k: Optional[int],
multiclass: Optional[bool],
ignore_index: Optional[int],
mode: DataType
mode: DataType,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
@@ -89,7 +89,10 @@ def _accuracy_compute(


def _subset_accuracy_update(
preds: Tensor, target: Tensor, threshold: float, top_k: Optional[int],
preds: Tensor,
target: Tensor,
threshold: float,
top_k: Optional[int],
) -> Tuple[Tensor, Tensor]:

preds, target = _input_squeeze(preds, target)
6 changes: 5 additions & 1 deletion torchmetrics/metric.py
@@ -24,6 +24,7 @@
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3


class Metric(nn.Module, ABC):
@@ -256,7 +257,10 @@ def reset(self):
"""
This method automatically resets the metric state variables to their default value.
"""
self._computed = None
# lower lightning versions implicitly require the cached value to survive reset
# in order to log metric objects correctly via self.log
if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
self._computed = None

for attr, default in self._defaults.items():
current_val = getattr(self, attr)
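The guard above is the core of the backward-compatibility fix: the cached compute result is only cleared when it is safe to do so. A hypothetical helper (not part of the codebase) that spells out the same condition:

def _should_clear_computed_cache(lightning_available: bool, lightning_ge_1_3: bool) -> bool:
    # hypothetical helper mirroring the guard in Metric.reset():
    #   no lightning installed -> clear the cache
    #   lightning >= 1.3       -> clear the cache
    #   lightning < 1.3        -> keep self._computed, because those versions read the
    #                             cached value when logging metric objects via self.log
    return not lightning_available or lightning_ge_1_3
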
16 changes: 12 additions & 4 deletions torchmetrics/utilities/imports.py
@@ -15,9 +15,10 @@
import operator
from importlib import import_module
from importlib.util import find_spec
from typing import Optional

from packaging.version import Version
from pkg_resources import DistributionNotFound
from pkg_resources import DistributionNotFound, get_distribution


def _module_available(module_path: str) -> bool:
@@ -39,7 +40,7 @@ def _module_available(module_path: str) -> bool:
return False


def _compare_version(package: str, op, version) -> bool:
def _compare_version(package: str, op, version) -> Optional[bool]:
"""
Compare package version with some requirements
@@ -49,10 +50,15 @@
"""
try:
pkg = import_module(package)
pkg_version = pkg.__version__
except (ModuleNotFoundError, DistributionNotFound):
return False
return None
except ImportError:
# catches cyclic imports - the case with integrated libs
# see: https://stackoverflow.com/a/32965521
pkg_version = get_distribution(package).version
try:
pkg_version = Version(pkg.__version__)
pkg_version = Version(pkg_version)
except TypeError:
# this is mocked by Sphinx, so it shall return True to generate all summaries
return True
@@ -64,3 +70,5 @@
_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_LIGHTNING_AVAILABLE = _module_available("pytorch_lightning")
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
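
A hedged usage sketch of the updated helper; the exact return values depend on what is installed in the environment, and the package name in the second call is only a placeholder:

import operator

from torchmetrics.utilities.imports import _compare_version

_compare_version("torch", operator.ge, "1.6.0")  # True or False, based on the installed torch
_compare_version("some_missing_package", operator.ge, "1.0")  # None: the package cannot be found
# if importing the package raises ImportError (e.g. a cyclic import during integration),
# the version is read from pkg_resources.get_distribution(package).version instead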
