From ee5ef5f3bf4bad7d19b803e348b8fb9b3c3b65a4 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 6 Jan 2022 15:05:06 +0100 Subject: [PATCH 01/12] Optionally Avoid recomputing features --- torchmetrics/image/fid.py | 17 ++++++++++++++++- torchmetrics/image/kid.py | 14 ++++++++++++++ torchmetrics/metric.py | 4 +++- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index a88033c92a1..f1f7b2add03 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -159,7 +159,10 @@ class FID(Metric): - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size. - compute_on_step: + reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not + change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if + your dataset does not change. + compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` @@ -209,6 +212,7 @@ class FID(Metric): def __init__( self, feature: Union[int, torch.nn.Module] = 2048, + reset_real_features: bool = True, compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -248,6 +252,14 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) + if reset_real_features: + exclude_states = () + else: + exclude_states = ('real_features',) + + self._reset_excluded_states = exclude_states + + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. @@ -282,3 +294,6 @@ def compute(self) -> Tensor: # compute fid return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) + + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + super().reset(set([*self._reset_excluded_states, *exclude_states])) diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 6e691e06cb3..ddd1fbacdfa 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -110,6 +110,9 @@ class KID(Metric): Scale-length of polynomial kernel. If set to ``None`` will be automatically set to the feature size coef: Bias term in the polynomial kernel. + reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not + change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if + your dataset does not change. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: @@ -174,6 +177,7 @@ def __init__( degree: int = 3, gamma: Optional[float] = None, # type: ignore coef: float = 1.0, + reset_real_features: bool = True, compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -235,6 +239,13 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) + if reset_real_features: + exclude_states = () + else: + exclude_states = ('real_features',) + + self._reset_excluded_states = exclude_states + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. @@ -276,3 +287,6 @@ def compute(self) -> Tuple[Tensor, Tensor]: kid_scores_.append(o) kid_scores = torch.stack(kid_scores_) return kid_scores.mean(), kid_scores.std(unbiased=False) + + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + super().reset(set([*self._reset_excluded_states, *exclude_states])) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 3b1211297c8..ac6cbfe3e77 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -394,7 +394,7 @@ def compute(self) -> Any: """Override this method to compute the final metric value from state variables synchronized across the distributed backend.""" - def reset(self) -> None: + def reset(self, exclude_states: Optional[Sequence['str']] = None) -> None: """This method automatically resets the metric state variables to their default value.""" self._update_called = False self._forward_cache = None @@ -403,6 +403,8 @@ def reset(self) -> None: self._computed = None for attr, default in self._defaults.items(): + if exclude_states is not None and attr in exclude_states: + continue current_val = getattr(self, attr) if isinstance(default, Tensor): setattr(self, attr, default.detach().clone().to(current_val.device)) From 4adc391688383e28cc782309e6a44dc37162ca02 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Jan 2022 14:06:36 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/image/fid.py | 9 ++++----- torchmetrics/image/kid.py | 6 +++--- torchmetrics/metric.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index f1f7b2add03..9046f09374f 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -160,9 +160,9 @@ class FID(Metric): an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size. reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not - change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if + change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if your dataset does not change. - compute_on_step: + compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` @@ -255,11 +255,10 @@ def __init__( if reset_real_features: exclude_states = () else: - exclude_states = ('real_features',) + exclude_states = ("real_features",) self._reset_excluded_states = exclude_states - def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. @@ -296,4 +295,4 @@ def compute(self) -> Tensor: return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - super().reset(set([*self._reset_excluded_states, *exclude_states])) + super().reset({*self._reset_excluded_states, *exclude_states}) diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index ddd1fbacdfa..f75d71a6fc8 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -111,7 +111,7 @@ class KID(Metric): coef: Bias term in the polynomial kernel. reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not - change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if + change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if your dataset does not change. compute_on_step: Forward only calls ``update()`` and return ``None`` if this is set to ``False``. @@ -242,7 +242,7 @@ def __init__( if reset_real_features: exclude_states = () else: - exclude_states = ('real_features',) + exclude_states = ("real_features",) self._reset_excluded_states = exclude_states @@ -289,4 +289,4 @@ def compute(self) -> Tuple[Tensor, Tensor]: return kid_scores.mean(), kid_scores.std(unbiased=False) def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - super().reset(set([*self._reset_excluded_states, *exclude_states])) + super().reset({*self._reset_excluded_states, *exclude_states}) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index ac6cbfe3e77..b4e2f751d1d 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -394,7 +394,7 @@ def compute(self) -> Any: """Override this method to compute the final metric value from state variables synchronized across the distributed backend.""" - def reset(self, exclude_states: Optional[Sequence['str']] = None) -> None: + def reset(self, exclude_states: Optional[Sequence["str"]] = None) -> None: """This method automatically resets the metric state variables to their default value.""" self._update_called = False self._forward_cache = None From 590cd0b4bf64d2318049be6700c77430bfc7f002 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 6 Jan 2022 15:08:04 +0100 Subject: [PATCH 03/12] Update metric.py --- torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index b4e2f751d1d..f15d85b169f 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -394,7 +394,7 @@ def compute(self) -> Any: """Override this method to compute the final metric value from state variables synchronized across the distributed backend.""" - def reset(self, exclude_states: Optional[Sequence["str"]] = None) -> None: + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: """This method automatically resets the metric state variables to their default value.""" self._update_called = False self._forward_cache = None From 5f2600ddd95b98d5ef752134932d7cdd35dfde19 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 20 Jan 2022 17:29:10 +0100 Subject: [PATCH 04/12] Apply suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/image/fid.py | 5 +---- torchmetrics/image/kid.py | 5 +---- torchmetrics/metric.py | 5 ++++- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 20dd842e0f3..7d64773fea2 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -255,10 +255,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - if reset_real_features: - exclude_states = () - else: - exclude_states = ("real_features",) + exclude_states = () if reset_real_features else ("real_features", ) self._reset_excluded_states = exclude_states diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 3a8a2bc00e9..69eaa1afff9 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -241,10 +241,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - if reset_real_features: - exclude_states = () - else: - exclude_states = ("real_features",) + exclude_states = () if reset_real_features else ("real_features", ) self._reset_excluded_states = exclude_states diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 6d66c8a3769..04932a75b04 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -394,7 +394,10 @@ def compute(self) -> Any: distributed backend.""" def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """This method automatically resets the metric state variables to their default value.""" + """This method automatically resets the metric state variables to their default value. + Args: + exclude_stetes: sequence of strings indicating metric states that should not be reset. + """ self._update_called = False self._forward_cache = None # lower lightning versions requires this implicitly to log metric objects correctly in self.log From 89152b9333543446dce247b66d0fed85d4aaec57 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jan 2022 16:31:47 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/image/fid.py | 3 ++- torchmetrics/image/kid.py | 2 +- torchmetrics/metric.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 7d64773fea2..dbf8e36ba0d 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -255,7 +255,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - exclude_states = () if reset_real_features else ("real_features", ) + exclude_states = () if reset_real_features else ("real_features",) self._reset_excluded_states = exclude_states @@ -297,6 +297,7 @@ def compute(self) -> Tensor: def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: super().reset({*self._reset_excluded_states, *exclude_states}) + class FID(FrechetInceptionDistance): r""" Calculates Fréchet inception distance (FID_) which is used to access the quality of generated images. diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 69eaa1afff9..c1f976012ac 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -241,7 +241,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - exclude_states = () if reset_real_features else ("real_features", ) + exclude_states = () if reset_real_features else ("real_features",) self._reset_excluded_states = exclude_states diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 04932a75b04..85a19f08688 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -395,6 +395,7 @@ def compute(self) -> Any: def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: """This method automatically resets the metric state variables to their default value. + Args: exclude_stetes: sequence of strings indicating metric states that should not be reset. """ From 6b83c36ab8b7087c6b58995d22ddf23e04b4d212 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 21 Jan 2022 10:57:33 +0100 Subject: [PATCH 06/12] add exclude states to other collections/trackers --- torchmetrics/collections.py | 10 +++++++--- torchmetrics/metric.py | 2 +- torchmetrics/wrappers/minmax.py | 12 ++++++++---- torchmetrics/wrappers/multioutput.py | 10 +++++++--- torchmetrics/wrappers/tracker.py | 10 +++++++--- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index ad903fe906a..749994cb7db 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -124,10 +124,14 @@ def update(self, *args: Any, **kwargs: Any) -> None: def compute(self) -> Dict[str, Any]: return {k: m.compute() for k, m in self.items()} - def reset(self) -> None: - """Iteratively call reset for each metric.""" + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + """Iteratively call reset for each metric. + + Args: + exclude_states: sequence of strings indicating metric states that should not be reset. + """ for _, m in self.items(keep_base=True): - m.reset() + m.reset(exclude_states=exclude_states) def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": """Make a copy of the metric collection diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 85a19f08688..657bc7de0e6 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -397,7 +397,7 @@ def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: """This method automatically resets the metric state variables to their default value. Args: - exclude_stetes: sequence of strings indicating metric states that should not be reset. + exclude_states: sequence of strings indicating metric states that should not be reset. """ self._update_called = False self._forward_cache = None diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 8386516f7a3..48c1207a0fc 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -104,10 +104,14 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} - def reset(self) -> None: - """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric.""" - super().reset() - self._base_metric.reset() + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric. + + Args: + exclude_states: sequence of strings indicating metric states that should not be reset. + """ + super().reset(exclude_states) + self._base_metric.reset(exclude_states) @staticmethod def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index e8317b5a903..723e590aeec 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -160,7 +160,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return None return results - def reset(self) -> None: - """Reset all underlying metrics.""" + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + """Reset all underlying metrics. + + Args: + exclude_states: sequence of strings indicating metric states that should not be reset. + """ for metric in self.metrics: - metric.reset() + metric.reset(exclude_states) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 551af66aa4f..50bd8fee54e 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -98,9 +98,13 @@ def compute_all(self) -> Tensor: self._check_for_increment("compute_all") return torch.stack([metric.compute() for i, metric in enumerate(self) if i != 0], dim=0) - def reset(self) -> None: - """Resets the current metric being tracked.""" - self[-1].reset() + def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + """Resets the current metric being tracked. + + Args: + exclude_states: sequence of strings indicating metric states that should not be reset. + """ + self[-1].reset(exclude_states) def reset_all(self) -> None: """Resets all metrics being tracked.""" From bc7c07d3efb5c4cd35bcd03a109dd6070ebae447 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 15 Feb 2022 13:12:33 +0100 Subject: [PATCH 07/12] Update torchmetrics/image/fid.py --- torchmetrics/image/fid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 68963a437a2..290cb1a7640 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -298,4 +298,5 @@ def compute(self) -> Tensor: return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: + exclude_states = exclude_states or () super().reset({*self._reset_excluded_states, *exclude_states}) From 0827d8057256c0a2d81137b827a7a5affc1b80df Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 11 Apr 2022 09:56:44 +0200 Subject: [PATCH 08/12] revert --- torchmetrics/collections.py | 10 ++--- torchmetrics/metric.py | 10 +---- torchmetrics/wrappers/minmax.py | 12 ++---- torchmetrics/wrappers/multioutput.py | 10 ++--- torchmetrics/wrappers/tracker.py | 57 ++++++---------------------- 5 files changed, 24 insertions(+), 75 deletions(-) diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index c3b65ccb225..3c8b58df1c7 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -226,14 +226,10 @@ def compute(self) -> Dict[str, Any]: res = _flatten_dict(res) return {self._set_name(k): v for k, v in res.items()} - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """Iteratively call reset for each metric. - - Args: - exclude_states: sequence of strings indicating metric states that should not be reset. - """ + def reset(self) -> None: + """Iteratively call reset for each metric.""" for _, m in self.items(keep_base=True): - m.reset(exclude_states=exclude_states) + m.reset() def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": """Make a copy of the metric collection diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index c93d59a9187..8efea6b72bf 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -433,19 +433,13 @@ def compute(self) -> Any: """Override this method to compute the final metric value from state variables synchronized across the distributed backend.""" - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """This method automatically resets the metric state variables to their default value. - - Args: - exclude_states: sequence of strings indicating metric states that should not be reset. - """ + def reset(self) -> None: + """This method automatically resets the metric state variables to their default value.""" self._update_called = False self._forward_cache = None self._computed = None for attr, default in self._defaults.items(): - if exclude_states is not None and attr in exclude_states: - continue current_val = getattr(self, attr) if isinstance(default, Tensor): setattr(self, attr, default.detach().clone().to(current_val.device)) diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 5f7095465de..f8451eb3058 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -95,14 +95,10 @@ def compute(self) -> Dict[str, Tensor]: # type: ignore self.min_val = val if self.min_val > val else self.min_val return {"raw": val, "max": self.max_val, "min": self.min_val} - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric. - - Args: - exclude_states: sequence of strings indicating metric states that should not be reset. - """ - super().reset(exclude_states) - self._base_metric.reset(exclude_states) + def reset(self) -> None: + """Sets ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric.""" + super().reset() + self._base_metric.reset() @staticmethod def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 7c478baf510..b1f55d3c6d8 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -142,11 +142,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return None return results - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """Reset all underlying metrics. - - Args: - exclude_states: sequence of strings indicating metric states that should not be reset. - """ + def reset(self) -> None: + """Reset all underlying metrics.""" for metric in self.metrics: - metric.reset(exclude_states) + metric.reset() diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index c3bce1dffa9..273cb9019f5 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from copy import deepcopy from typing import Any, Dict, List, Tuple, Union @@ -142,13 +141,9 @@ def compute_all(self) -> Tensor: return {k: torch.stack([r[k] for r in res], dim=0) for k in keys} return torch.stack(res, dim=0) - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - """Resets the current metric being tracked. - - Args: - exclude_states: sequence of strings indicating metric states that should not be reset. - """ - self[-1].reset(exclude_states) + def reset(self) -> None: + """Resets the current metric being tracked.""" + self[-1].reset() def reset_all(self) -> None: """Resets all metrics being tracked.""" @@ -157,14 +152,7 @@ def reset_all(self) -> None: def best_metric( self, return_step: bool = False - ) -> Union[ - None, - float, - Tuple[int, float], - Tuple[None, None], - Dict[str, Union[float, None]], - Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]], - ]: + ) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]: """Returns the highest metric out of all tracked. Args: @@ -175,39 +163,18 @@ def best_metric( """ if isinstance(self._base_metric, Metric): fn = torch.max if self.maximize else torch.min - try: - idx, best = fn(self.compute_all(), 0) - if return_step: - return idx.item(), best.item() - return best.item() - except ValueError as error: - warnings.warn( - f"Encountered the following error when trying to get the best metric: {error}" - "this is probably due to the 'best' not being defined for this metric." - "Returning `None` instead.", - UserWarning, - ) - if return_step: - return None, None - return None - - else: # this is a metric collection + idx, best = fn(self.compute_all(), 0) + if return_step: + return idx.item(), best.item() + return best.item() + else: res = self.compute_all() maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] idx, best = {}, {} for i, (k, v) in enumerate(res.items()): - try: - fn = torch.max if maximize[i] else torch.min - out = fn(v, 0) - idx[k], best[k] = out[0].item(), out[1].item() - except ValueError as error: - warnings.warn( - f"Encountered the following error when trying to get the best metric for metric {k}:" - f"{error} this is probably due to the 'best' not being defined for this metric." - "Returning `None` instead.", - UserWarning, - ) - idx[k], best[k] = None, None + fn = torch.max if maximize[i] else torch.min + out = fn(v, 0) + idx[k], best[k] = out[0].item(), out[1].item() if return_step: return idx, best From d13ca9e34a95fc8bbc306c1bc885e8bee7d72078 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 11 Apr 2022 09:57:59 +0200 Subject: [PATCH 09/12] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55742672139..8392b810502 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `adaptive_k` for the `RetrievalPrecision` metric ([#910](https://github.com/PyTorchLightning/metrics/pull/910)) +- Added `reset_real_features` argument image quality assesment metrics ([#722](https://github.com/PyTorchLightning/metrics/pull/722)) + + ### Changed - Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914)) From 8a1ebbbabf4e7a0893388eea1061db9c27365ce1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 11 Apr 2022 10:07:13 +0200 Subject: [PATCH 10/12] new implementation --- torchmetrics/image/fid.py | 25 +++++++++++++++++-------- torchmetrics/image/kid.py | 23 ++++++++++++++++------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 143783175ec..51986d3fee6 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -164,6 +164,7 @@ class FrechetInceptionDistance(Metric): reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if your dataset does not change. + compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. @@ -189,6 +190,8 @@ class FrechetInceptionDistance(Metric): If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048] TypeError: If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module`` + ValueError: + If ``reset_real_features`` is not an ``bool`` Example: >>> import torch @@ -206,7 +209,8 @@ class FrechetInceptionDistance(Metric): """ real_features: List[Tensor] fake_features: List[Tensor] - higher_is_better = False + higher_is_better: bool = False + is_differentiable: bool = False def __init__( self, @@ -241,13 +245,13 @@ def __init__( else: raise TypeError("Got unknown input to argument `feature`") + if not isinstance(reset_real_features, bool): + raise ValueError("Arugment `reset_real_features` expected to be a bool") + self.reset_real_features = reset_real_features + self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - exclude_states = () if reset_real_features else ("real_features",) - - self._reset_excluded_states = exclude_states - def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. @@ -283,6 +287,11 @@ def compute(self) -> Tensor: # compute fid return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - exclude_states = exclude_states or () - super().reset({*self._reset_excluded_states, *exclude_states}) + def reset(self) -> None: + if not self.reset_real_features: + # remove temporarily to avoid resetting + value = self._defaults.pop("real_features") + super().reset() + self._defaults["real_features"] = value + else: + super().reset() diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index b5db8183ed8..9bc70069f50 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -148,6 +148,8 @@ class KernelInceptionDistance(Metric): If ``gamma`` is niether ``None`` or a float larger than 0 ValueError: If ``coef`` is not an float larger than 0 + ValueError: + If ``reset_real_features`` is not an ``bool`` Example: >>> import torch @@ -166,7 +168,8 @@ class KernelInceptionDistance(Metric): """ real_features: List[Tensor] fake_features: List[Tensor] - higher_is_better = False + higher_is_better: bool = False + is_differentiable: bool = False def __init__( self, @@ -226,14 +229,14 @@ def __init__( raise ValueError("Argument `coef` expected to be float larger than 0") self.coef = coef + if not isinstance(reset_real_features, bool): + raise ValueError("Arugment `reset_real_features` expected to be a bool") + self.reset_real_features = reset_real_features + # states for extracted features self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - exclude_states = () if reset_real_features else ("real_features",) - - self._reset_excluded_states = exclude_states - def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. @@ -276,5 +279,11 @@ def compute(self) -> Tuple[Tensor, Tensor]: kid_scores = torch.stack(kid_scores_) return kid_scores.mean(), kid_scores.std(unbiased=False) - def reset(self, exclude_states: Optional[Sequence[str]] = None) -> None: - super().reset({*self._reset_excluded_states, *exclude_states}) + def reset(self) -> None: + if not self.reset_real_features: + # remove temporarily to avoid resetting + value = self._defaults.pop("real_features") + super().reset() + self._defaults["real_features"] = value + else: + super().reset() From e93758ced0a007805a06c296a7e5883d660d265a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 11 Apr 2022 10:15:55 +0200 Subject: [PATCH 11/12] add tests --- tests/image/test_fid.py | 25 +++++++++++++++++++++++++ tests/image/test_kid.py | 25 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py index 2c6d7455cf9..dd44af83c50 100644 --- a/tests/image/test_fid.py +++ b/tests/image/test_fid.py @@ -153,3 +153,28 @@ def test_compare_fid(tmpdir, feature=2048): tm_res = metric.compute() assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid["frechet_inception_distance"]]), atol=1e-3) + + +@pytest.mark.parametrize("reset_real_features", [True, False]) +def test_reset_real_features_arg(reset_real_features): + metric = FrechetInceptionDistance(feature=64, reset_real_features=reset_real_features) + + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True) + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False) + + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] + + assert len(metric.fake_features) == 1 + assert list(metric.fake_features[0].shape) == [2, 64] + + metric.reset() + + # fake features should always reset + assert len(metric.fake_features) == 0 + + if reset_real_features: + assert len(metric.real_features) == 0 + else: + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] diff --git a/tests/image/test_kid.py b/tests/image/test_kid.py index c9459bc57c5..dca29cd1c97 100644 --- a/tests/image/test_kid.py +++ b/tests/image/test_kid.py @@ -163,3 +163,28 @@ def test_compare_kid(tmpdir, feature=2048): assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["kernel_inception_distance_mean"]]), atol=1e-3) assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid["kernel_inception_distance_std"]]), atol=1e-3) + + +@pytest.mark.parametrize("reset_real_features", [True, False]) +def test_reset_real_features_arg(reset_real_features): + metric = KernelInceptionDistance(feature=64, reset_real_features=reset_real_features) + + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True) + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False) + + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] + + assert len(metric.fake_features) == 1 + assert list(metric.fake_features[0].shape) == [2, 64] + + metric.reset() + + # fake features should always reset + assert len(metric.fake_features) == 0 + + if reset_real_features: + assert len(metric.real_features) == 0 + else: + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] From e817ad1ad79a4298cef182afcb9a72c7037824e5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 11 Apr 2022 10:20:05 +0200 Subject: [PATCH 12/12] revert --- torchmetrics/wrappers/tracker.py | 47 ++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 273cb9019f5..84544c1c1dc 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy from typing import Any, Dict, List, Tuple, Union @@ -152,7 +153,14 @@ def reset_all(self) -> None: def best_metric( self, return_step: bool = False - ) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]: + ) -> Union[ + None, + float, + Tuple[int, float], + Tuple[None, None], + Dict[str, Union[float, None]], + Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]], + ]: """Returns the highest metric out of all tracked. Args: @@ -163,18 +171,39 @@ def best_metric( """ if isinstance(self._base_metric, Metric): fn = torch.max if self.maximize else torch.min - idx, best = fn(self.compute_all(), 0) - if return_step: - return idx.item(), best.item() - return best.item() - else: + try: + idx, best = fn(self.compute_all(), 0) + if return_step: + return idx.item(), best.item() + return best.item() + except ValueError as error: + warnings.warn( + f"Encountered the following error when trying to get the best metric: {error}" + "this is probably due to the 'best' not being defined for this metric." + "Returning `None` instead.", + UserWarning, + ) + if return_step: + return None, None + return None + + else: # this is a metric collection res = self.compute_all() maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] idx, best = {}, {} for i, (k, v) in enumerate(res.items()): - fn = torch.max if maximize[i] else torch.min - out = fn(v, 0) - idx[k], best[k] = out[0].item(), out[1].item() + try: + fn = torch.max if maximize[i] else torch.min + out = fn(v, 0) + idx[k], best[k] = out[0].item(), out[1].item() + except ValueError as error: + warnings.warn( + f"Encountered the following error when trying to get the best metric for metric {k}:" + f"{error} this is probably due to the 'best' not being defined for this metric." + "Returning `None` instead.", + UserWarning, + ) + idx[k], best[k] = None, None if return_step: return idx, best