
Can't TorchScript LightningModule when using Metric #4416

Closed
hudeven opened this issue Oct 28, 2020 · 8 comments · Fixed by #4428
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments


hudeven commented Oct 28, 2020

🐛 Bug

Please reproduce using the BoringModel and post here

Able to reproduce it in https://colab.research.google.com/drive/1MscNHxIc_LIbZxALHbZOAkooNu0TzVly?usp=sharing

To Reproduce

Expected behavior

Able to TorchScript a LightningModule regardless of whether a Metric is used.

It seems hard to make Metric torchscriptable, as *args and **kwargs are useful in Python but not supported in TorchScript.
Since Metric is not needed for inference, I think it should be excluded when calling LightningModule.to_torchscript().
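To illustrate the *args/**kwargs limitation, here is a minimal stand-alone sketch (KwargsModule is a hypothetical stand-in, not the BoringModel from the notebook):

```python
import torch
import torch.nn as nn

class KwargsModule(nn.Module):
    # hypothetical stand-in for a Metric-like module that accepts **kwargs
    def forward(self, x: torch.Tensor, **kwargs):
        return x * 2

try:
    torch.jit.script(KwargsModule())
    scriptable = True
except Exception:
    # TorchScript's frontend rejects variable-length argument lists
    scriptable = False

print(scriptable)  # False
```

Any module holding such a submodule hits the same error when scripted, even if forward() never calls it.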

Environment

Note: Bugs with code are solved faster! Colab notebooks should be made public!

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0+cu101
    • pytorch-lightning: 0.10.0
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

@hudeven added the bug and help wanted labels on Oct 28, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@ananthsub
Contributor

@hudeven as a workaround can you override to_torchscript in your LightningModule?


hudeven commented Oct 29, 2020

@hudeven as a workaround can you override to_torchscript in your LightningModule?

Yeah, it works by overriding to_torchscript() and deleting the metric attributes there.
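The workaround can be sketched in plain PyTorch (UnscriptableMetric and the deepcopy-then-delete approach are illustrative, not Lightning's actual implementation):

```python
import copy
import torch
import torch.nn as nn

class UnscriptableMetric(nn.Module):
    # hypothetical stand-in for a Metric whose forward takes *args/**kwargs
    def forward(self, *args, **kwargs):
        return torch.tensor(0.0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.metric = UnscriptableMetric()  # only needed during training

    def forward(self, x: torch.Tensor):
        return self.linear(x)

    def to_torchscript(self):
        # script a copy with the metric attribute deleted, mirroring the
        # idea of overriding LightningModule.to_torchscript()
        model = copy.deepcopy(self)
        del model.metric
        return torch.jit.script(model)

scripted = Model().to_torchscript()
out = scripted(torch.randn(1, 4))
print(tuple(out.shape))  # (1, 2)
```

Deep-copying first keeps the original training module intact; nn.Module's __delattr__ removes the submodule cleanly before scripting.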


snisarg commented Oct 29, 2020

A controversial suggestion but could we use another method name instead of forward? I understand using the class name directly is convenient.

Also curious about why the metrics class is an nn.Module? Is it to avoid the pains of syncing across distributed envs instead of using torch.distributed?


NumesSanguis commented Oct 29, 2020

Having Metrics treated specially would also be helpful when saving/loading .ckpt files. While this is intended behavior (a Metric is an nn.Module), I also run into trouble: #4361

Metrics are not necessary for inference, but if you use them, you need to set strict=False when loading a .ckpt. This creates a risk that actual problems (the ones you want loading to fail on) will be ignored.


NumesSanguis commented Oct 29, 2020

@hudeven I don't know if you need to use TorchScript in combination with method='script', but with method='trace' your code does work.

Add this to your script:

def training_step(self, batch, batch_idx):
    # use first batch to create an example input
    if self.example_input_array is None:
        # we only need 1 sample, not a whole batch, but keep the batch dimension
        self.example_input_array = batch[0, :].unsqueeze(dim=0)

And change the call to model.to_torchscript(method='trace') (instead of method='script').

This does not solve the actual issue at hand, but can be a workaround for some here.

@teddykoker
Contributor

Also curious about why the metrics class is an nn.Module? Is it to avoid the pains of syncing across distributed envs instead of using torch.distributed?

@snisarg, we are using an nn.Module for the metrics so that the state of the metric can be moved .to() different devices, which is necessary if you want to use them in both Lightning and plain PyTorch. We are still using torch.distributed to sync the metric states across GPUs.
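A minimal sketch of why the nn.Module base helps (CountMetric is a hypothetical stand-in, not the real Metric API): state registered as a buffer travels with .to() like any other module state.

```python
import torch
import torch.nn as nn

class CountMetric(nn.Module):
    # hypothetical stand-in for a Metric: state lives in a registered buffer
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.tensor(0.0))

    def update(self, x: torch.Tensor):
        self.count += x.numel()

m = CountMetric()
m.update(torch.randn(3))
print(float(m.count))  # 3.0
# because `count` is a buffer, m.to("cuda") would move it along with any
# model that holds this metric as a submodule
```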


hudeven commented Nov 2, 2020

@NumesSanguis thanks for the workaround! I intend to use 'script'. This issue is fixed in #4428
