
googlenet fires assert for not being an nn.Module with torch.jit.trace() #2161

Closed
mneilly-et opened this issue Apr 29, 2020 · 2 comments


@mneilly-et

🐛 Bug

Trying to trace the pretrained googlenet model with torch.jit.trace() fails with an assertion:

Traceback (most recent call last):
  File "./gn.py", line 14, in <module>
    m = torch.jit.trace(model, input)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 1021, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 720, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 1884, in __init__
    tmp_module._modules[name] = make_module(submodule, TracedModule, _compilation_unit=None)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 720, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/eng/mneilly/.virtualenvs/machine_learning/lib/python3.6/site-packages/torch/jit/__init__.py", line 1845, in __init__
    assert(isinstance(orig, torch.nn.Module))
AssertionError

To Reproduce

import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image

# Random 224x224 RGB image used as the example input
pixels = np.random.rand(224, 224, 3) * 255
img = Image.fromarray(pixels.astype('uint8')).convert('RGB')
input = transforms.ToTensor()(img).unsqueeze_(0)  # shape: (1, 3, 224, 224)

model = torchvision.models.googlenet(pretrained=True).eval()
m = torch.jit.trace(model, input)  # raises the AssertionError shown above
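A plausible reading of the traceback, not confirmed in this thread: the assertion fires while the tracer converts child modules (`make_module(submodule, ...)`), so some entry in a `_modules` dict is not an `nn.Module`. In PyTorch, assigning `None` to an already-registered submodule keeps its key in `_modules` with a `None` value. The sketch below shows this with a hypothetical `Toy` module and a hypothetical helper, `none_submodules`, that reports such slots. It assumes, again unverified here, that `googlenet(pretrained=True)` removes its auxiliary classifiers this way when `aux_logits` is off.

```python
from torch import nn


class Toy(nn.Module):
    """Hypothetical model with an optional auxiliary head (stand-in for googlenet)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        self.aux = nn.Linear(4, 2)

    def forward(self, x):
        return self.features(x)


def none_submodules(module: nn.Module, prefix: str = ""):
    """Hypothetical helper: dotted names of child slots that hold None.

    named_children() skips None entries, so _modules is inspected directly.
    """
    found = []
    for name, child in module._modules.items():
        full = f"{prefix}.{name}" if prefix else name
        if child is None:
            found.append(full)
        else:
            found.extend(none_submodules(child, full))
    return found


model = Toy()
# ASSUMPTION: mirrors how an optional head might be dropped after loading weights.
# The key "aux" stays in model._modules, now mapped to None.
model.aux = None
print(none_submodules(model))  # → ['aux']
```

If the assumption holds, calling `none_submodules` on the googlenet model from the reproduction above would be expected to list the auxiliary-classifier slots (`aux1` and `aux2` in torchvision's `GoogLeNet`).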

Expected behavior

Model traces successfully.

Environment

PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: CentOS Linux release 7.6.1810 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-39)
CMake version: version 3.15.3

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 440.64.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.2
[pip3] torch==1.5.0
[pip3] torch-glow==0.0.0
[pip3] torchvision==0.6.0
[conda] Could not collect

@mthrok
Contributor

mthrok commented Apr 29, 2020

@fmassa I can reproduce this.

@fmassa
Member

fmassa commented Apr 30, 2020

Hi,

Instead of tracing, I recommend using torch.jit.script, which is more general: it supports control flow and a wider range of input types.
In torchvision, every model supports torch.jit.script, but not every model supports torch.jit.trace.
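To make the control-flow point concrete, here is a minimal sketch with a hypothetical function `scale` whose branch depends on the input's values. Tracing records only the operations executed for the example input, so the branch gets frozen into the graph; scripting compiles the Python source and keeps both branches.

```python
import torch


def scale(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the tensor's values.
    if bool(x.sum() > 0):
        return x * 2
    return x * -1


pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing with `pos` records only the `x * 2` path; PyTorch typically emits
# a TracerWarning about converting a tensor to a Python boolean.
traced = torch.jit.trace(scale, pos)

# Scripting compiles the function, so both branches survive.
scripted = torch.jit.script(scale)

print(traced(neg))    # → tensor([-2., -2., -2.])  (wrong branch baked in)
print(scripted(neg))  # → tensor([1., 1., 1.])
```

For inputs that take the same path as the example input, both versions agree; they diverge only when the data would select the other branch.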

So here is what I would do:

import torch
import torchvision
model = torchvision.models.googlenet(pretrained=True).eval()
m = torch.jit.script(model)

As such, I'm closing this issue, but let me know if you have further questions.
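A minimal sketch of a scripting round trip, assuming a small hypothetical `TinyNet` in place of googlenet so that no weights need to be downloaded: the module is compiled with `torch.jit.script`, serialized with `torch.jit.save`, reloaded with `torch.jit.load`, and its output compared with the eager model's.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical small CNN standing in for googlenet."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)  # (N, 8, 1, 1) -> (N, 8)
        return self.fc(x)


model = TinyNet().eval()
scripted = torch.jit.script(model)

# Serialize to an in-memory buffer; a file path such as "model.pt" works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    eager_out = model(example)
    loaded_out = loaded(example)

print(torch.allclose(eager_out, loaded_out))
```

Swapping `TinyNet()` for `torchvision.models.googlenet(pretrained=True)` follows the same steps; the in-memory buffer only keeps the example self-contained.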
