Skip to content

Commit

Permalink
enable more PL integrations (#739)
Browse files Browse the repository at this point in the history
* enable more Pl integrations
* if: success()
* try 1.1
* try 1.2
* try 1.3
* fix imports
* use own sum/diff
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
4 people authored Jan 14, 2022
1 parent 7595120 commit 07cca4d
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 116 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ jobs:
fail_ci_if_error: false

- name: Integrations
if: success()
run: |
pip install -r requirements/integrate.txt --quiet --upgrade-strategy only-if-needed --find-links $PYTORCH_URL
pip uninstall -y torchmetrics
Expand Down
180 changes: 79 additions & 101 deletions integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,12 @@

from integrations.lightning.boring_model import BoringModel, RandomDataset
from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3
from torchmetrics import Accuracy, AveragePrecision, Metric
from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric


class SumMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self, x):
self.x += x

def compute(self):
return self.x


class DiffMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self, x):
self.x -= x

def compute(self):
return self.x
class DiffMetric(SumMetric):
def update(self, value):
super().update(-value)


def test_metric_lightning(tmpdir):
Expand Down Expand Up @@ -201,45 +182,81 @@ def _assert_called(model, stage):
_assert_called(model, "test")


# todo: reconsider if it make sense to keep here
# def test_metric_lightning_log(tmpdir):
# """ Test logging a metric object and that the metric state gets reset after each epoch."""
# class TestModel(BoringModel):
# def __init__(self):
# super().__init__()
# self.metric_step = SumMetric()
# self.metric_epoch = SumMetric()
# self.sum = 0.0
#
# def on_epoch_start(self):
# self.sum = 0.0
#
# def training_step(self, batch, batch_idx):
# x = batch
# self.metric_step(x.sum())
# self.sum += x.sum()
# self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
# return {'loss': self.step(x), 'data': x}
#
# def training_epoch_end(self, outs):
# self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
#
# model = TestModel()
# model.val_dataloader = None
#
# trainer = Trainer(
# default_root_dir=tmpdir,
# limit_train_batches=2,
# limit_val_batches=2,
# max_epochs=2,
# log_every_n_steps=1,
# weights_summary=None,
# )
# trainer.fit(model)
#
# logged = trainer.logged_metrics
# assert torch.allclose(tensor(logged["sum_step"]), model.sum)
# assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)
def test_metric_lightning_log(tmpdir):
"""Test logging a metric object and that the metric state gets reset after each epoch."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric_step = SumMetric()
self.metric_epoch = SumMetric()
self.sum = 0.0

def on_epoch_start(self):
self.sum = 0.0

def training_step(self, batch, batch_idx):
x = batch
self.metric_step(x.sum())
self.sum += x.sum()
self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
return {"loss": self.step(x), "data": x}

def training_epoch_end(self, outs):
self.log("sum_epoch", self.metric_epoch(torch.stack([o["data"] for o in outs]).sum()))

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=1,
)
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(tensor(logged["sum_step"]), model.sum)
assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)


def test_metric_collection_lightning_log(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = MetricCollection([SumMetric(), DiffMetric()])
self.sum = 0.0
self.diff = 0.0

def training_step(self, batch, batch_idx):
x = batch
metric_vals = self.metric(x.sum())
self.sum += x.sum()
self.diff -= x.sum()
self.log_dict({f"{k}_step": v for k, v in metric_vals.items()})
return self.step(x)

def training_epoch_end(self, outputs):
metric_vals = self.metric.compute()
self.log_dict({f"{k}_epoch": v for k, v in metric_vals.items()})

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=0,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)


# todo: need to be fixed
# def test_scriptable(tmpdir):
Expand Down Expand Up @@ -278,42 +295,3 @@ def _assert_called(model, stage):
# output = model(rand_input)
# script_output = script_model(rand_input)
# assert torch.allclose(output, script_output)

# def test_metric_collection_lightning_log(tmpdir):
#
# class TestModel(BoringModel):
#
# def __init__(self):
# super().__init__()
# self.metric = MetricCollection([SumMetric(), DiffMetric()])
# self.sum = 0.0
# self.diff = 0.0
#
# def training_step(self, batch, batch_idx):
# x = batch
# metric_vals = self.metric(x.sum())
# self.sum += x.sum()
# self.diff -= x.sum()
# self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
# return self.step(x)
#
# def training_epoch_end(self, outputs):
# metric_vals = self.metric.compute()
# self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})
#
# model = TestModel()
# model.val_dataloader = None
#
# trainer = Trainer(
# default_root_dir=tmpdir,
# limit_train_batches=2,
# limit_val_batches=2,
# max_epochs=1,
# log_every_n_steps=1,
# weights_summary=None,
# )
# trainer.fit(model)
#
# logged = trainer.logged_metrics
# assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
# assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)
2 changes: 1 addition & 1 deletion requirements/integrate.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pytorch-lightning>=1.0
pytorch-lightning>=1.3
6 changes: 4 additions & 2 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.text.bert import bert_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.chrf import chrf_score
Expand All @@ -80,13 +79,16 @@
from torchmetrics.functional.text.wer import wer, word_error_rate
from torchmetrics.functional.text.wil import word_information_lost
from torchmetrics.functional.text.wip import word_information_preserved
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401

__all__ = [
"accuracy",
"auc",
"auroc",
"average_precision",
"bert_score",
"bleu_score",
"calibration_error",
"chrf_score",
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@
from torchmetrics.functional.text.wer import wer, word_error_rate # noqa: F401
from torchmetrics.functional.text.wil import word_information_lost # noqa: F401
from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401
8 changes: 4 additions & 4 deletions torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from torch.utils.data import DataLoader, Dataset

from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from transformers import AutoModel, AutoTokenizer
if _TRANSFORMERS_AUTO_AVAILABLE:
from transformers.models.auto import AutoModel, AutoTokenizer

if _TQDM_AVAILABLE:
import tqdm
Expand Down Expand Up @@ -580,7 +580,7 @@ def bert_score(
)

if model is None:
if not _TRANSFORMERS_AVAILABLE:
if not _TRANSFORMERS_AUTO_AVAILABLE:
raise ModuleNotFoundError(
"`bert_score` metric with default models requires `transformers` package be installed."
" Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@
from torchmetrics.text.wer import WER, WordErrorRate # noqa: F401
from torchmetrics.text.wil import WordInfoLost # noqa: F401
from torchmetrics.text.wip import WordInfoPreserved # noqa: F401
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
from torchmetrics.text.bert import BERTScore # noqa: F401
11 changes: 5 additions & 6 deletions torchmetrics/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
import torch
from deprecate import deprecated

from torchmetrics.functional import bert_score
from torchmetrics.functional.text.bert import _preprocess_text
from torchmetrics.functional.text.bert import _preprocess_text, bert_score
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from transformers import AutoTokenizer
if _TRANSFORMERS_AUTO_AVAILABLE:
from transformers.models.auto import AutoTokenizer


# Default model recommended in the original implementation.
Expand Down Expand Up @@ -178,7 +177,7 @@ def __init__(
self.tokenizer = user_tokenizer
self.user_tokenizer = True
else:
if not _TRANSFORMERS_AVAILABLE:
if not _TRANSFORMERS_AUTO_AVAILABLE:
raise ModuleNotFoundError(
"`BERTScore` metric with default tokenizers requires `transformers` package be installed."
" Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."
Expand Down
7 changes: 5 additions & 2 deletions torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _module_available(module_path: str) -> bool:
>>> _module_available('os')
True
>>> _module_available('os.bla')
False
>>> _module_available('bla.bla')
False
"""
Expand All @@ -35,7 +37,7 @@ def _module_available(module_path: str) -> bool:
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
except (ImportError, ModuleNotFoundError):
# Python 3.7+
return False

Expand All @@ -55,7 +57,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
pkg_version = pkg.__version__ # type: ignore
except (ModuleNotFoundError, DistributionNotFound):
return None
except ImportError:
except (ImportError, AttributeError):
# catches cyclic imports - the case with integrated libs
# see: https://stackoverflow.com/a/32965521
pkg_version = get_distribution(package).version
Expand Down Expand Up @@ -87,6 +89,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_TORCHVISION_GREATER_EQUAL_0_8: Optional[bool] = _compare_version("torchvision", operator.ge, "0.8.0")
_TQDM_AVAILABLE: bool = _module_available("tqdm")
_TRANSFORMERS_AVAILABLE: bool = _module_available("transformers")
_TRANSFORMERS_AUTO_AVAILABLE = _module_available("transformers.models.auto")
_PESQ_AVAILABLE: bool = _module_available("pesq")
_SACREBLEU_AVAILABLE: bool = _module_available("sacrebleu")
_REGEX_AVAILABLE: bool = _module_available("regex")
Expand Down

0 comments on commit 07cca4d

Please sign in to comment.