diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml
index eab5bd0b86a..e075cb9a09f 100644
--- a/.github/workflows/ci_test-full.yml
+++ b/.github/workflows/ci_test-full.yml
@@ -128,6 +128,7 @@ jobs:
         fail_ci_if_error: false

     - name: Integrations
+      if: success()
       run: |
         pip install -r requirements/integrate.txt --quiet --upgrade-strategy only-if-needed --find-links $PYTORCH_URL
         pip uninstall -y torchmetrics
diff --git a/integrations/test_lightning.py b/integrations/test_lightning.py
index 34dd8e07bb2..379153a5423 100644
--- a/integrations/test_lightning.py
+++ b/integrations/test_lightning.py
@@ -21,31 +21,12 @@
 from integrations.lightning.boring_model import BoringModel, RandomDataset
 from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3
-from torchmetrics import Accuracy, AveragePrecision, Metric
+from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric


-class SumMetric(Metric):
-    def __init__(self):
-        super().__init__()
-        self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
-
-    def update(self, x):
-        self.x += x
-
-    def compute(self):
-        return self.x
-
-
-class DiffMetric(Metric):
-    def __init__(self):
-        super().__init__()
-        self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
-
-    def update(self, x):
-        self.x -= x
-
-    def compute(self):
-        return self.x
+class DiffMetric(SumMetric):
+    def update(self, value):
+        super().update(-value)


 def test_metric_lightning(tmpdir):
@@ -201,45 +182,81 @@ def _assert_called(model, stage):
     _assert_called(model, "test")


-# todo: reconsider if it make sense to keep here
-# def test_metric_lightning_log(tmpdir):
-#     """ Test logging a metric object and that the metric state gets reset after each epoch."""
-#     class TestModel(BoringModel):
-#         def __init__(self):
-#             super().__init__()
-#             self.metric_step = SumMetric()
-#             self.metric_epoch = SumMetric()
-#             self.sum = 0.0
-#
-#         def on_epoch_start(self):
-#             self.sum = 0.0
-#
-#         def training_step(self, batch, batch_idx):
-#             x = batch
-#             self.metric_step(x.sum())
-#             self.sum += x.sum()
-#             self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
-#             return {'loss': self.step(x), 'data': x}
-#
-#         def training_epoch_end(self, outs):
-#             self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
-#
-#     model = TestModel()
-#     model.val_dataloader = None
-#
-#     trainer = Trainer(
-#         default_root_dir=tmpdir,
-#         limit_train_batches=2,
-#         limit_val_batches=2,
-#         max_epochs=2,
-#         log_every_n_steps=1,
-#         weights_summary=None,
-#     )
-#     trainer.fit(model)
-#
-#     logged = trainer.logged_metrics
-#     assert torch.allclose(tensor(logged["sum_step"]), model.sum)
-#     assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)
+def test_metric_lightning_log(tmpdir):
+    """Test logging a metric object and that the metric state gets reset after each epoch."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric_step = SumMetric()
+            self.metric_epoch = SumMetric()
+            self.sum = 0.0
+
+        def on_epoch_start(self):
+            self.sum = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            self.metric_step(x.sum())
+            self.sum += x.sum()
+            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
+            return {"loss": self.step(x), "data": x}
+
+        def training_epoch_end(self, outs):
+            self.log("sum_epoch", self.metric_epoch(torch.stack([o["data"] for o in outs]).sum()))
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+    assert torch.allclose(tensor(logged["sum_step"]), model.sum)
+    assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)
+
+
+def test_metric_collection_lightning_log(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric = MetricCollection([SumMetric(), DiffMetric()])
+            self.sum = 0.0
+            self.diff = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            metric_vals = self.metric(x.sum())
+            self.sum += x.sum()
+            self.diff -= x.sum()
+            self.log_dict({f"{k}_step": v for k, v in metric_vals.items()})
+            return self.step(x)
+
+        def training_epoch_end(self, outputs):
+            metric_vals = self.metric.compute()
+            self.log_dict({f"{k}_epoch": v for k, v in metric_vals.items()})
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+    assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
+    assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)


 # todo: need to be fixed
 # def test_scriptable(tmpdir):
@@ -278,42 +295,3 @@ def _assert_called(model, stage):
 #     output = model(rand_input)
 #     script_output = script_model(rand_input)
 #     assert torch.allclose(output, script_output)
-
-# def test_metric_collection_lightning_log(tmpdir):
-#
-#     class TestModel(BoringModel):
-#
-#         def __init__(self):
-#             super().__init__()
-#             self.metric = MetricCollection([SumMetric(), DiffMetric()])
-#             self.sum = 0.0
-#             self.diff = 0.0
-#
-#         def training_step(self, batch, batch_idx):
-#             x = batch
-#             metric_vals = self.metric(x.sum())
-#             self.sum += x.sum()
-#             self.diff -= x.sum()
-#             self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
-#             return self.step(x)
-#
-#         def training_epoch_end(self, outputs):
-#             metric_vals = self.metric.compute()
-#             self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})
-#
-#     model = TestModel()
-#     model.val_dataloader = None
-#
-#     trainer = Trainer(
-#         default_root_dir=tmpdir,
-#         limit_train_batches=2,
-#         limit_val_batches=2,
-#         max_epochs=1,
-#         log_every_n_steps=1,
-#         weights_summary=None,
-#     )
-#     trainer.fit(model)
-#
-#     logged = trainer.logged_metrics
-#     assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
-#     assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)
diff --git a/requirements/integrate.txt b/requirements/integrate.txt
index 5c2802a7f46..3acf3f8e78a 100644
--- a/requirements/integrate.txt
+++ b/requirements/integrate.txt
@@ -1 +1 @@
-pytorch-lightning>=1.0
+pytorch-lightning>=1.3
diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py
index c9a149ab441..95660715f2c 100644
--- a/torchmetrics/functional/__init__.py
+++ b/torchmetrics/functional/__init__.py
@@ -67,7 +67,6 @@
 from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
 from torchmetrics.functional.retrieval.recall import retrieval_recall
 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
-from torchmetrics.functional.text.bert import bert_score
 from torchmetrics.functional.text.bleu import bleu_score
 from torchmetrics.functional.text.cer import char_error_rate
 from torchmetrics.functional.text.chrf import chrf_score
@@ -80,13 +79,16 @@
 from torchmetrics.functional.text.wer import wer, word_error_rate
 from torchmetrics.functional.text.wil import word_information_lost
 from torchmetrics.functional.text.wip import word_information_preserved
+from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE
+
+if _TRANSFORMERS_AUTO_AVAILABLE:
+    from torchmetrics.functional.text.bert import bert_score  # noqa: F401

 __all__ = [
     "accuracy",
     "auc",
     "auroc",
     "average_precision",
-    "bert_score",
     "bleu_score",
     "calibration_error",
     "chrf_score",
diff --git a/torchmetrics/functional/text/__init__.py b/torchmetrics/functional/text/__init__.py
index 2430ec68e4c..93ec450f8d0 100644
--- a/torchmetrics/functional/text/__init__.py
+++ b/torchmetrics/functional/text/__init__.py
@@ -22,3 +22,7 @@
 from torchmetrics.functional.text.wer import wer, word_error_rate  # noqa: F401
 from torchmetrics.functional.text.wil import word_information_lost  # noqa: F401
 from torchmetrics.functional.text.wip import word_information_preserved  # noqa: F401
+from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE
+
+if _TRANSFORMERS_AUTO_AVAILABLE:
+    from torchmetrics.functional.text.bert import bert_score  # noqa: F401
diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py
index 12d7372c481..1792a86d1d9 100644
--- a/torchmetrics/functional/text/bert.py
+++ b/torchmetrics/functional/text/bert.py
@@ -24,10 +24,10 @@
 from torch.utils.data import DataLoader, Dataset

 from torchmetrics.utilities import _future_warning
-from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AVAILABLE
+from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE

-if _TRANSFORMERS_AVAILABLE:
-    from transformers import AutoModel, AutoTokenizer
+if _TRANSFORMERS_AUTO_AVAILABLE:
+    from transformers.models.auto import AutoModel, AutoTokenizer

 if _TQDM_AVAILABLE:
     import tqdm
@@ -580,7 +580,7 @@ def bert_score(
     )

     if model is None:
-        if not _TRANSFORMERS_AVAILABLE:
+        if not _TRANSFORMERS_AUTO_AVAILABLE:
             raise ModuleNotFoundError(
                 "`bert_score` metric with default models requires `transformers` package be installed."
                 " Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."
diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py
index 49f73b7ad71..b88139f510a 100644
--- a/torchmetrics/text/__init__.py
+++ b/torchmetrics/text/__init__.py
@@ -22,3 +22,7 @@
 from torchmetrics.text.wer import WER, WordErrorRate  # noqa: F401
 from torchmetrics.text.wil import WordInfoLost  # noqa: F401
 from torchmetrics.text.wip import WordInfoPreserved  # noqa: F401
+from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE
+
+if _TRANSFORMERS_AUTO_AVAILABLE:
+    from torchmetrics.text.bert import BERTScore  # noqa: F401
diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py
index aec0c39e949..5a43205bd95 100644
--- a/torchmetrics/text/bert.py
+++ b/torchmetrics/text/bert.py
@@ -17,14 +17,13 @@
 import torch
 from deprecate import deprecated

-from torchmetrics.functional import bert_score
-from torchmetrics.functional.text.bert import _preprocess_text
+from torchmetrics.functional.text.bert import _preprocess_text, bert_score
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import _future_warning
-from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
+from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE

-if _TRANSFORMERS_AVAILABLE:
-    from transformers import AutoTokenizer
+if _TRANSFORMERS_AUTO_AVAILABLE:
+    from transformers.models.auto import AutoTokenizer


 # Default model recommended in the original implementation.
@@ -178,7 +177,7 @@ def __init__(
             self.tokenizer = user_tokenizer
             self.user_tokenizer = True
         else:
-            if not _TRANSFORMERS_AVAILABLE:
+            if not _TRANSFORMERS_AUTO_AVAILABLE:
                 raise ModuleNotFoundError(
                     "`BERTScore` metric with default tokenizers requires `transformers` package be installed."
                     " Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index fdd74f4394b..220882863c9 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -27,6 +27,8 @@ def _module_available(module_path: str) -> bool:
     >>> _module_available('os')
     True
+    >>> _module_available('os.bla')
+    False
     >>> _module_available('bla.bla')
     False
     """
@@ -35,7 +37,7 @@ def _module_available(module_path: str) -> bool:
     except AttributeError:
         # Python 3.6
         return False
-    except ModuleNotFoundError:
+    except (ImportError, ModuleNotFoundError):
         # Python 3.7+
         return False

@@ -55,7 +57,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
         pkg_version = pkg.__version__  # type: ignore
     except (ModuleNotFoundError, DistributionNotFound):
         return None
-    except ImportError:
+    except (ImportError, AttributeError):
         # catches cyclic imports - the case with integrated libs
         # see: https://stackoverflow.com/a/32965521
         pkg_version = get_distribution(package).version
@@ -87,6 +89,7 @@
 _TORCHVISION_GREATER_EQUAL_0_8: Optional[bool] = _compare_version("torchvision", operator.ge, "0.8.0")
 _TQDM_AVAILABLE: bool = _module_available("tqdm")
 _TRANSFORMERS_AVAILABLE: bool = _module_available("transformers")
+_TRANSFORMERS_AUTO_AVAILABLE = _module_available("transformers.models.auto")
 _PESQ_AVAILABLE: bool = _module_available("pesq")
 _SACREBLEU_AVAILABLE: bool = _module_available("sacrebleu")
 _REGEX_AVAILABLE: bool = _module_available("regex")
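
Context for the availability probe this patch relies on: checking the dotted path "transformers.models.auto" rather than bare "transformers" verifies that the installed version actually ships the Auto classes before bert_score/BERTScore are exported. Below is a minimal, self-contained sketch of that pattern, assuming importlib.util.find_spec semantics; it mirrors but is not TorchMetrics' exact code, and the helper name module_available plus the final guarded import are illustrative only:

    from importlib.util import find_spec


    def module_available(module_path: str) -> bool:
        """Return True if the dotted ``module_path`` can be imported."""
        try:
            # find_spec imports parent packages first, so probing
            # "transformers.models.auto" exercises the same path a real import would
            return find_spec(module_path) is not None
        except AttributeError:
            # Python 3.6 raised AttributeError when the parent is not a package
            return False
        except (ImportError, ModuleNotFoundError):
            # Python 3.7+: e.g. module_available("os.bla") lands here,
            # because `os` is a module, not a package
            return False


    if module_available("transformers.models.auto"):
        from transformers.models.auto import AutoModel, AutoTokenizer  # noqa: F401

Catching plain ImportError as well as ModuleNotFoundError matters because probing a submodule can trigger import of the parent package, whose own import-time side effects may raise ImportError.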