From 8571c18d5208b8c25669e6cdccc4c137cc3e5141 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 1 Apr 2021 16:57:57 +0200 Subject: [PATCH 1/3] deepcopy --- torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 4ea08e36025..ce503274f86 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -259,7 +259,7 @@ def reset(self): for attr, default in self._defaults.items(): current_val = getattr(self, attr) if isinstance(default, Tensor): - setattr(self, attr, deepcopy(default).to(current_val.device)) + setattr(self, attr, default.detach().clone().to(current_val.device)) else: setattr(self, attr, deepcopy(default)) From 16751a62e5dac2e6281349d05386d3e53d180ed2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 1 Apr 2021 17:01:32 +0200 Subject: [PATCH 2/3] deepcopy --- torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index ce503274f86..6cc88ad9ce4 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -261,7 +261,7 @@ def reset(self): if isinstance(default, Tensor): setattr(self, attr, default.detach().clone().to(current_val.device)) else: - setattr(self, attr, deepcopy(default)) + setattr(self, attr, []) def clone(self): """ Make a copy of the metric """ From 4ac6a02a2fc8536af8345105b5b454ebd140c159 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 9 Apr 2021 10:54:41 +0200 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fbf0b1c90d..de3b0727c6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) - Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134)) - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111)) +- Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163)) ### Deprecated