Commit

Remove lightning legacy code and references (#788)
* remove lightning
* more removal
* fix test
* flake8
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Aki Nitta <nitta@akihironitta.com>
5 people authored Feb 5, 2022
1 parent 385b18e commit b225889
Showing 22 changed files with 83 additions and 99 deletions.
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
@@ -40,7 +40,7 @@ help you or finish it with you :\]_

1. Add/update the relevant tests!

- [This PR](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241) is a good example for adding a new metric
- [This PR](https://github.com/PyTorchLightning/metrics/pull/98) is a good example for adding a new metric

### Test cases:

2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -5,7 +5,7 @@ Fixes #\<issue_number>
## Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/metrics/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

- Removed support for versions of Lightning lower than v1.5 ([#788](https://github.com/PyTorchLightning/metrics/pull/788))


- Removed deprecated functions, and warnings in Text ([#773](https://github.com/PyTorchLightning/metrics/pull/773))
* `functional.wer`
* `WER`
2 changes: 1 addition & 1 deletion docs/paper_JOSS/paper.md
@@ -99,7 +99,7 @@ In addition to stateful metrics (called modular metrics in TorchMetrics), we als

TorchMetrics exhibits high test coverage on the various configurations, including all three major OS platforms (Linux, macOS, and Windows), and various Python, CUDA, and PyTorch versions. We test both minimum and latest package requirements for all combinations of OS and Python versions and include additional tests for each PyTorch version from 1.3 up to future development versions. On every pull request and merge to master, we run a full test suite. All standard tests run on CPU. In addition, we run all tests on a multi-GPU setting which reflects realistic Deep Learning workloads. For usability, we have auto-generated HTML documentation (hosted at [readthedocs](https://torchmetrics.readthedocs.io/en/stable/)) from the source code which updates in real-time with new merged pull requests.

TorchMetrics is released under the Apache 2.0 license. The source code is available at https://github.com/PytorchLightning/metrics.
TorchMetrics is released under the Apache 2.0 license. The source code is available at https://github.com/PyTorchLightning/metrics.

# Acknowledgement

5 changes: 3 additions & 2 deletions docs/source/governance.rst
@@ -41,8 +41,9 @@ Project Management and Decision Making
**************************************

The decision what goes into a release is governed by the :ref:`staff contributors and leaders <governance>` of
Lightning development. Whenever possible, discussion happens publicly on GitHub and includes the whole community.
When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request and start tracking the development. It is possible that priorities change over time.
TorchMetrics development. Whenever possible, discussion happens publicly on GitHub and includes the whole community.
When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request
and start tracking the development. It is possible that priorities change over time.

Commits to the project are exclusively to be added by pull requests on GitHub and anyone in the community is welcome to review them.
However, reviews submitted by
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
@@ -44,7 +44,7 @@ Internal implementation details
-------------------------------

This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
Internally, TorchMetrics wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally:

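For orientation, the pattern described above is what a user-defined metric plugs into: only ``update()`` and ``compute()`` are written by hand, and the wrapped versions take care of device placement and state synchronization. A minimal sketch of such a metric (the class name and the accuracy logic are illustrative, not part of this diff):

```python
import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    """Minimal custom metric; name and logic are illustrative only."""

    def __init__(self):
        super().__init__()
        # registered states are what get synchronized and reduced across devices/processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # only the registered states are touched; the wrapped update handles the rest
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.correct.float() / self.total
```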
4 changes: 4 additions & 0 deletions docs/source/pages/lightning.rst
@@ -11,6 +11,10 @@ TorchMetrics in PyTorch Lightning

TorchMetrics was originally created as part of `PyTorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_, a powerful deep learning research framework designed for scaling models without boilerplate.

.. note::
    TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend
    always keeping both frameworks up-to-date for the best experience.

While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits:

* Module metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics.
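To make the first benefit above concrete, here is a minimal sketch of a LightningModule that logs a TorchMetrics metric. The model, shapes, and metric choice are placeholders, and newer TorchMetrics releases may require extra arguments (e.g. ``task=``) when constructing ``Accuracy``:

```python
import torch
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy


class LitClassifier(LightningModule):
    """Illustrative module: layer sizes and data shapes are placeholders."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 10)
        # defined as a submodule attribute, the metric moves with the module across devices
        self.train_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # log the metric object itself so Lightning can handle epoch aggregation and reset
        self.train_acc(logits.argmax(dim=-1), y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Passing the metric object (rather than a computed value) to ``self.log`` is what lets Lightning call ``compute()`` and ``reset()`` at the right points in the epoch.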
5 changes: 0 additions & 5 deletions integrations/__init__.py
@@ -1,10 +1,5 @@
import operator
import os

from torchmetrics.utilities.imports import _compare_version

_INTEGRATION_ROOT = os.path.realpath(os.path.dirname(__file__))
_PACKAGE_ROOT = os.path.dirname(_INTEGRATION_ROOT)
_PATH_DATASETS = os.path.join(_PACKAGE_ROOT, "datasets")

_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
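For readers unfamiliar with the removed gate: ``_compare_version`` checks the installed version of a package at import time so tests can be skipped on old releases. A rough, illustrative equivalent (not the actual torchmetrics helper), assuming ``packaging`` is available:

```python
import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def compare_version(package: str, op, target: str) -> bool:
    """Rough stand-in for the removed ``_compare_version`` gate (not the real helper)."""
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        return False
    return op(installed, Version(target))


# spirit of the removed line, bumped to the new minimum supported release
LIGHTNING_GE_1_5 = compare_version("pytorch_lightning", operator.ge, "1.5.0")
```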
80 changes: 40 additions & 40 deletions integrations/test_lightning.py
@@ -13,14 +13,12 @@
# limitations under the License.
from unittest import mock

import pytest
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import tensor
from torch.utils.data import DataLoader

from integrations.lightning.boring_model import BoringModel, RandomDataset
from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3
from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric


@@ -63,7 +61,6 @@ def training_epoch_end(self, outs):
    trainer.fit(model)


@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher")
def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch.
@@ -222,6 +219,8 @@ def training_epoch_end(self, outs):


def test_metric_collection_lightning_log(tmpdir):
    """Test that MetricCollection works with Lightning modules."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
@@ -258,40 +257,41 @@ def training_epoch_end(self, outputs):
    assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)


# todo: need to be fixed
# def test_scriptable(tmpdir):
#     class TestModel(BoringModel):
#         def __init__(self):
#             super().__init__()
#             # the metric is not used in the module's `forward`
#             # so the module should be exportable to TorchScript
#             self.metric = SumMetric()
#             self.sum = 0.0
#
#         def training_step(self, batch, batch_idx):
#             x = batch
#             self.metric(x.sum())
#             self.sum += x.sum()
#             self.log("sum", self.metric, on_epoch=True, on_step=False)
#             return self.step(x)
#
#     model = TestModel()
#     trainer = Trainer(
#         default_root_dir=tmpdir,
#         limit_train_batches=2,
#         limit_val_batches=2,
#         max_epochs=1,
#         log_every_n_steps=1,
#         weights_summary=None,
#         logger=False,
#         checkpoint_callback=False,
#     )
#     trainer.fit(model)
#     rand_input = torch.randn(10, 32)
#
#     script_model = model.to_torchscript()
#
#     # test that we can still do inference
#     output = model(rand_input)
#     script_output = script_model(rand_input)
#     assert torch.allclose(output, script_output)
def test_scriptable(tmpdir):
    """Test that lightning modules can still be scripted even if metrics cannot."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # the metric is not used in the module's `forward`
            # so the module should be exportable to TorchScript
            self.metric = SumMetric()
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric(x.sum())
            self.sum += x.sum()
            self.log("sum", self.metric, on_epoch=True, on_step=False)
            return self.step(x)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model)
    rand_input = torch.randn(10, 32)

    script_model = model.to_torchscript()

    # test that we can still do inference
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)
3 changes: 0 additions & 3 deletions requirements/devel.txt
@@ -14,6 +14,3 @@
-r image_test.txt
-r text_test.txt
-r audio_test.txt

# add the integration dependencies
#-r integrate.txt
2 changes: 1 addition & 1 deletion requirements/docs.txt
@@ -12,4 +12,4 @@ sphinx-togglebutton>=0.2
sphinx-copybutton>=0.3

# integrations
pytorch-lightning>=1.1
-r integrate.txt
2 changes: 1 addition & 1 deletion requirements/integrate.txt
@@ -1 +1 @@
pytorch-lightning>=1.3
pytorch-lightning>=1.5
9 changes: 3 additions & 6 deletions tests/bases/test_metric.py
@@ -20,9 +20,9 @@
import torch
from torch import Tensor, nn, tensor

from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

seed_all(42)

@@ -100,10 +100,7 @@ def test_reset_compute():
    a.update(tensor(5))
    assert a.compute() == 5
    a.reset()
    if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
        assert a.compute() == 0
    else:
        assert a.compute() == 5
    assert a.compute() == 0


def test_update():
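The simplification above reflects the new baseline: with Lightning releases older than v1.5 no longer supported, ``compute()`` after ``reset()`` always returns the metric's default state. A small sketch of the expected behavior, using the public ``SumMetric`` rather than the test-suite dummies:

```python
import torch
from torchmetrics import SumMetric

metric = SumMetric()
metric.update(torch.tensor(5.0))
assert metric.compute() == 5.0

metric.reset()
# after reset, compute() falls back to the default state (0 for a sum)
assert metric.compute() == 0.0
```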
5 changes: 1 addition & 4 deletions tests/helpers/__init__.py
@@ -1,17 +1,14 @@
import operator
import random

import numpy
import torch

from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6, _compare_version
from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6

_MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason="required PT >= 1.4")
_MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason="required PT >= 1.5")
_MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason="required PT >= 1.6")

_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")


def seed_all(seed):
    random.seed(seed)
18 changes: 9 additions & 9 deletions tests/helpers/testers.py
@@ -123,14 +123,14 @@ def _class_test(
    check_scriptable: bool = True,
    **kwargs_update: Any,
):
    """Utility function doing the actual comparison between lightning class metric and reference metric.
    """Utility function doing the actual comparison between class metric and reference metric.
    Args:
        rank: rank of current process
        worldsize: number of processes
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_class: lightning metric class that should be tested
        metric_class: metric class that should be tested
        sk_metric: callable function that is used for comparison
        dist_sync_on_step: bool, if true will synchronize metric state across
            processes at each ``forward()``
@@ -150,7 +150,7 @@ def _class_test(
    if not metric_args:
        metric_args = {}

    # Instantiate lightning metric
    # Instantiate metric
    metric = metric_class(
        compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args
    )
@@ -255,12 +255,12 @@ def _functional_test(
    fragment_kwargs: bool = False,
    **kwargs_update,
):
    """Utility function doing the actual comparison between lightning functional metric and reference metric.
    """Utility function doing the actual comparison between functional metric and reference metric.
    Args:
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_functional: lightning metric functional that should be tested
        metric_functional: metric functional that should be tested
        sk_metric: callable function that is used for comparison
        metric_args: dict with additional arguments used for class initialization
        device: determine which device to run on, either 'cuda' or 'cpu'
@@ -283,15 +283,15 @@ def _functional_test(

    for i in range(num_batches):
        extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
        lightning_result = metric(preds[i], target[i], **extra_kwargs)
        tm_result = metric(preds[i], target[i], **extra_kwargs)
        extra_kwargs = {
            k: v.cpu() if isinstance(v, Tensor) else v
            for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
        }
        sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs)

        # assert its the same
        _assert_allclose(lightning_result, sk_result, atol=atol)
        _assert_allclose(tm_result, sk_result, atol=atol)


def _assert_half_support(
@@ -366,7 +366,7 @@ def run_functional_metric_test(
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_functional: lightning metric class that should be tested
            metric_functional: metric class that should be tested
            sk_metric: callable function that is used for comparison
            metric_args: dict with additional arguments used for class initialization
            fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
@@ -408,7 +408,7 @@ def run_class_metric_test(
            ddp: bool, if running in ddp mode or not
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_class: lightning metric class that should be tested
            metric_class: metric class that should be tested
            sk_metric: callable function that is used for comparison
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
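The renamed docstrings and variables above all describe the same testing pattern: run a TorchMetrics implementation batch by batch and compare the result against a reference (typically scikit-learn) function. A condensed sketch of that pattern, with hypothetical data in place of the parametrized fixtures the real helpers use (newer releases may require ``task="binary"`` when constructing ``Accuracy``):

```python
import torch
from sklearn.metrics import accuracy_score
from torchmetrics import Accuracy

# hypothetical inputs; the real helpers generate many parametrized cases
preds = torch.randint(0, 2, (4, 32))
target = torch.randint(0, 2, (4, 32))

metric = Accuracy()
for i in range(preds.shape[0]):  # feed batches, as _class_test does
    metric.update(preds[i], target[i])
tm_result = metric.compute()

sk_result = accuracy_score(target.reshape(-1).numpy(), preds.reshape(-1).numpy())
assert torch.isclose(tm_result, torch.tensor(sk_result, dtype=tm_result.dtype))
```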
4 changes: 0 additions & 4 deletions tests/regression/test_mean_error.py
@@ -64,8 +64,6 @@ def _single_target_sk_metric(preds, target, sk_fn, metric_args):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    # `sk_target` and `sk_preds` switched to fix failing tests.
    # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277
    res = sk_fn(sk_target, sk_preds)

    return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res
@@ -75,8 +73,6 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
    sk_preds = preds.view(-1, num_targets).numpy()
    sk_target = target.view(-1, num_targets).numpy()

    # `sk_target` and `sk_preds` switched to fix failing tests.
    # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277
    res = sk_fn(sk_target, sk_preds)

    return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res
4 changes: 2 additions & 2 deletions tests/retrieval/helpers.py
@@ -369,7 +369,7 @@ def _errors_test_class_metric(
        indexes: torch tensor with indexes
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_class: lightning metric class that should be tested
        metric_class: metric class that should be tested
        message: message that exception should return
        metric_args: arguments for class initialization
        exception_type: callable function that is used for comparison
@@ -396,7 +396,7 @@ def _errors_test_functional_metric(
    Args:
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_functional: lightning functional metric that should be tested
        metric_functional: functional metric that should be tested
        message: message that exception should return
        exception_type: callable function that is used for comparison
        kwargs_update: Additional keyword arguments that will be passed with indexes, preds and