Skip to content

Commit

Permalink
Avoid torchscript export for Metric forward (#4428)
Browse files Browse the repository at this point in the history
* Update metric.py

* add test

* Update CHANGELOG.md

* Update test_metric_lightning.py

* Update test_metric_lightning.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
(cherry picked from commit 5d08559)
  • Loading branch information
ananthsub authored and Borda committed Nov 4, 2020
1 parent 38f4a83 commit e8a6e02
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))

- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed

Expand Down Expand Up @@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428))

## [1.0.4] - 2020-10-27

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def add_state(
self._defaults[name] = deepcopy(default)
self._reductions[name] = dist_reduce_fx

@torch.jit.unused
def forward(self, *args, **kwargs):
"""
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
Expand Down
41 changes: 40 additions & 1 deletion tests/metrics/test_metric_lightning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric
from tests.base.boring_model import BoringModel
Expand Down Expand Up @@ -78,3 +79,41 @@ def training_step(self, batch, batch_idx):

logged = trainer.logged_metrics
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)


def test_scriptable(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
# the metric is not used in the module's `forward`
# so the module should be exportable to TorchScript
self.metric = SumMetric()
self.sum = 0.0

def training_step(self, batch, batch_idx):
x = batch
self.metric(x.sum())
self.sum += x.sum()
self.log("sum", self.metric, on_epoch=True, on_step=False)
return self.step(x)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
logger=False,
checkpoint_callback=False,
)
trainer.fit(model)
rand_input = torch.randn(10, 32)

script_model = model.to_torchscript()

# test that we can still do inference
output = model(rand_input)
script_output = script_model(rand_input)
assert torch.allclose(output, script_output)

0 comments on commit e8a6e02

Please sign in to comment.