diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index fe78869911fda..c9fb053c7d336 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -218,3 +218,13 @@ def reset(self): setattr(self, attr, deepcopy(default).to(current_val.device)) else: setattr(self, attr, deepcopy(default)) + + def __getstate__(self): + # ignore update and compute functions for pickling + return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} + + def __setstate__(self, state): + # manually restore update and compute functions for pickling + self.__dict__.update(state) + self.update = self._wrap_update(self.update) + self.compute = self._wrap_compute(self.compute) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 62b9384b219d0..366c873127eab 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -4,6 +4,9 @@ import os import numpy as np +import pickle +import cloudpickle + torch.manual_seed(42) @@ -106,3 +109,30 @@ def compute(self): # called without update, should return cached value a._computed = 5 assert a.compute() == 5 + + +class ToPickle(Dummy): + def update(self, x): + self.x += x + + def compute(self): + return self.x + + +def test_pickle(tmpdir): + # doesn't tests for DDP + a = ToPickle() + a.update(1) + + metric_pickled = pickle.dumps(a) + metric_loaded = pickle.loads(metric_pickled) + + assert metric_loaded.compute() == 1 + + metric_loaded.update(5) + assert metric_loaded.compute() == 5 + + metric_pickled = cloudpickle.dumps(a) + metric_loaded = cloudpickle.loads(metric_pickled) + + assert metric_loaded.compute() == 1 diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 88997006b0014..3ba87f456ad43 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -3,6 +3,7 @@ import os import sys import pytest +import pickle NUM_PROCESSES = 2 NUM_BATCHES = 10 @@ -19,6 +20,10 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args) + # verify metrics work after being loaded from pickled state + pickled_metric = pickle.dumps(metric) + metric = pickle.loads(pickled_metric) + # Only use ddp if world size if worldsize > 1: setup_ddp(rank, worldsize)