Metric ddp bugfix #4482

Merged
merged 19 commits on Nov 10, 2020
Changes from 11 commits
6 changes: 6 additions & 0 deletions docs/source/metrics.rst
@@ -111,6 +111,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be used
It is highly recommended to re-initialize the metric per mode as
shown in the examples above.

.. note::

By default, metric states are added to the model's ``state_dict``. To change this
after the metric has been initialized, use the ``.persistent(mode)`` method to
enable (``mode=True``) or disable (``mode=False``) this behaviour.
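
For illustration, a minimal sketch of toggling persistence after initialization (``Accuracy`` here is just an example; any metric subclass behaves the same):

from pytorch_lightning.metrics import Accuracy

metric = Accuracy()
metric.persistent(False)  # exclude this metric's states from state_dict
metric.persistent(True)   # include them again (the default behaviour)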

*********************
Implementing a Metric
*********************
43 changes: 34 additions & 9 deletions pytorch_lightning/metrics/metric.py
@@ -81,8 +81,9 @@ def __init__(
self._forward_cache = None

# initialize state
self._reductions = {}
self._defaults = {}
self._persistent = {}
self._reductions = {}

def add_state(
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -138,16 +139,10 @@ def add_state(
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
)

if isinstance(default, torch.Tensor):
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
# persistent keyword is only supported in torch >= 1.6.0
self.register_buffer(name, default, persistent=persistent)
else:
self.register_buffer(name, default)
else:
setattr(self, name, default)
setattr(self, name, default)

self._defaults[name] = deepcopy(default)
self._persistent[name] = persistent
self._reductions[name] = dist_reduce_fx

@torch.jit.unused
@@ -265,3 +260,33 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

def _apply(self, fn):
""" Overwrite _apply function such that we can also move metric states
to the correct divice when `.to`, `.cuda` ect methods are called
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
"""
self = super()._apply(fn)
# Also apply fn to metric states
for key in self._defaults.keys():
current_val = getattr(self, key, None)
# metric states are either single tensors or lists of tensors
if isinstance(current_val, torch.Tensor):
setattr(self, key, fn(current_val))
elif isinstance(current_val, (list, tuple)):
setattr(self, key, [fn(cur_v) for cur_v in current_val])
return self

def persistent(self, mode: bool = True):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
for key in self._persistent.keys():
self._persistent[key] = mode

def state_dict(self, destination=None, prefix='', keep_vars=False):
# Register metric states to be part of the state_dict
# pass destination/prefix through so states also land in a parent module's state_dict
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
for key in self._defaults.keys():
if self._persistent[key]:
current_val = getattr(self, key)
state_dict.update({prefix + key: current_val})
return state_dict
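
Taken together, these overrides keep metric states on the right device and let users decide whether the states are checkpointed. A minimal sketch of the intended behaviour (this ``SumMetric`` is illustrative and only mirrors the helper used in the tests below, whose body is collapsed in this view):

import torch
from pytorch_lightning.metrics import Metric

class SumMetric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x

metric = SumMetric()
metric.update(torch.tensor(2.0))
assert "x" in metric.state_dict()      # states are persistent by default
metric.persistent(False)               # post-init toggle added in this PR
assert "x" not in metric.state_dict()
if torch.cuda.is_available():
    metric = metric.cuda()             # _apply also moves the "x" state to the GPU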
60 changes: 54 additions & 6 deletions tests/metrics/test_metric_lightning.py
@@ -1,9 +1,11 @@
import os
import pytest

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric
from tests.base.boring_model import BoringModel
import tests.base.develop_utils as tutils


class SumMetric(Metric):
@@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = SumMetric()
self.metric_step = SumMetric()
self.metric_epoch = SumMetric()
self.sum = 0.0

def training_step(self, batch, batch_idx):
x = batch
self.metric(x.sum())
self.metric_step(x.sum())
self.sum += x.sum()
self.log("sum", self.metric, on_epoch=True, on_step=False)
return self.step(x)
self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
return {'loss': self.step(x), 'data': x}

def training_epoch_end(self, outs):
self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))

model = TestModel()
model.val_dataloader = None
@@ -78,7 +84,49 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)


class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = SumMetric()
self.p = torch.nn.Linear(1,1) # fake params

def training_step(self, batch, batch_idx):
val = self.metric(batch[0].sum())
self.log("sum", self.metric, on_step=False, on_epoch=True)
return self.p(val.view(1,1))

def configure_optimizers(self):
return None


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_metric_lightning_ddp(tmpdir):
tutils.set_random_master_port()

# Dummy dataset, where sum is known
data = torch.arange(10)[:,None].float()
dataset = torch.utils.data.TensorDataset(data)
dataloader = torch.utils.data.DataLoader(dataset)

model = TestModel()
trainer = Trainer(
gpus=2,
max_epochs=1,
log_every_n_steps=1,
accelerator='ddp_spawn',
progress_bar_refresh_rate=0,
replace_sampler_ddp=False
)
trainer.fit(model, dataloader)

logged = trainer.logged_metrics

assert torch.tensor(logged["sum"]) == data.sum(), \
"Metrics did not accumulate correctly in ddp mode"

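For reference, a conceptual sketch of what ``dist_reduce_fx="sum"`` amounts to when the two ``ddp_spawn`` processes sync the metric state (this uses ``all_reduce`` for brevity and is not necessarily the exact sync path inside ``Metric``):

import torch
import torch.distributed as dist

def sync_sum(local_state: torch.Tensor) -> torch.Tensor:
    # every process contributes its partial sum; all processes receive the total
    dist.all_reduce(local_state, op=dist.ReduceOp.SUM)
    return local_state
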

def test_scriptable(tmpdir):