Module collision when loading more than one model with Torch Hub #243

Closed
carloalbertobono opened this issue Oct 8, 2021 · 5 comments
Labels
wontfix This will not be worked on

Comments

@carloalbertobono

Hi, I'm having an issue with loading multiple specific models with torch.hub.load

Originally posted the description here and @glenn-jocher suggested to post it here too

I'm pasting the original issue below:

Hi, I think I have an issue similar to 2414.
It prevents loading more than one model with torch.hub when specific models are involved.
If I'm reading the thread correctly, loading a model with torch.hub shadows some module names, which then become unusable within torch.

I'm using torch '1.9.1+cu102' on an Ubuntu 20.04 machine, and to reproduce I run:

import torch
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

that ends up in

~/.cache/torch/hub/ultralytics_yolov5_master/hubconf.py in _create(name, pretrained, channels, classes, autoshape, verbose, device)
     28     from pathlib import Path
     29 
---> 30     from models.yolo import Model
     31     from models.experimental import attempt_load
     32     from utils.general import check_requirements, set_logging

ModuleNotFoundError: No module named 'models.yolo'

Reversing the load order, as expected, ends with:

~/.cache/torch/hub/facebookresearch_detr_master/hubconf.py in <module>
      2 import torch
      3 
----> 4 from models.backbone import Backbone, Joiner
      5 from models.detr import DETR, PostProcess
      6 from models.position_encoding import PositionEmbeddingSine

ModuleNotFoundError: No module named 'models.backbone'

Is there some workaround which I'm not seeing?

Thank you very much, also for the awesome project itself
cb

@NicolasHug
Member

Thanks for the detailed report @carloalbertobono , I can reproduce the issue. I'll look into this

@NicolasHug
Member

As a temporary and ugly workaround, try this:

import torch
import sys

torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
sys.modules.pop('models')  # ¯\_(ツ)_/¯
torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

The error occurs because the modules imported by the first hubconf file (the detr one) are still present in the imported-module cache (sys.modules), even though the detr folder has been removed from sys.path. The second hubconf's `from models.yolo import ...` therefore resolves `models` to the cached detr package, which has no `yolo` submodule. I'll try to see if there's a better way of handling all this. It's kinda funny, but this code runs fine:

import torch

torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
import models  # there's no "models" dir or package but it still works because it's the "models" module from the detr repo and it's still in the modules cache
print(models)
# prints <module 'models' from '/Users/nicolashug/.cache/torch/hub/facebookresearch_detr_main/models/__init__.py'>
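
The cache behaviour described above can be reproduced without torch or network access. What follows is a minimal sketch, not torch.hub's actual code: it assumes two throwaway directories (the hypothetical repo_a and repo_b) that each ship a top-level models package, and the helper import_from_repo only imitates a hubconf import by putting the repo directory on sys.path for the duration of the import.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def make_repo(root: Path, name: str, submodule: str) -> Path:
    """Create a fake repo whose top-level `models` package holds one submodule."""
    pkg = root / name / "models"
    pkg.mkdir(parents=True)
    (pkg / "__init__.py").write_text("")
    (pkg / f"{submodule}.py").write_text(f"ORIGIN = {name!r}\n")
    return root / name


def import_from_repo(repo: Path, dotted_name: str):
    """Imitation of the hubconf pattern (an assumption, not torch.hub code):
    the repo directory is on sys.path only while the import runs."""
    importlib.invalidate_caches()  # the files were created at runtime
    sys.path.insert(0, str(repo))
    try:
        return importlib.import_module(dotted_name)
    finally:
        sys.path.remove(str(repo))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    repo_a = make_repo(root, "repo_a", "backbone")
    repo_b = make_repo(root, "repo_b", "yolo")

    # First load: caches `models` and `models.backbone` in sys.modules.
    first = import_from_repo(repo_a, "models.backbone")

    # Second load: `models` is served from the cache (repo_a), so the
    # submodule lookup for `yolo` searches repo_a/models and fails.
    try:
        import_from_repo(repo_b, "models.yolo")
        collided = False
    except ModuleNotFoundError:
        collided = True

    # Same idea as the workaround: evict the cached package, then retry.
    sys.modules.pop("models")
    second = import_from_repo(repo_b, "models.yolo")
```

Here the second import fails for the same reason as in the report, and succeeds once the cached `models` entry is evicted.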

@carloalbertobono
Author

Hi @NicolasHug , thank you very much for looking into this 🙏

I can confirm that the workaround works like a charm.
Plus, it will add an artistic comment (yours) to our codebase.

As far as I'm concerned, this is solved.
I guess I just need to avoid loading the models concurrently in the application, which I can patch anyway.

All the best
cb

@NicolasHug
Member

@vmoens and I tried to find a reasonable fix for this issue, but we could not find one that would be simple, non-magical, and fully torchscript-proof. My closest attempt is #247 (comment), but it still fails in a very specific case related to torchscript. Even if it didn't, I think the solution is a bit too magic to be reasonable.

Then again, this bug happens because of pre-existing torchhub magic, so it's no wonder that more magic is needed to fix it. Anyway - I'm afraid we'll have to mark this as a wont-fix issue. I provided a workaround above, and I'll make a PR to torch core to mention this in the "Known limitations" section of the torchhub docs page.

Thanks again for the report @carloalbertobono
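
For applications that need both models in one process, the sys.modules.pop workaround can be wrapped in a small helper. This is only a sketch: isolated_top_level is a hypothetical name, not part of torch.hub, and the caller has to know which top-level package names collide (models and utils are the ones visible in the tracebacks in this thread). It does not address the torchscript case mentioned above, and code that imports lazily after the block exits may still resolve to the wrong package.

```python
import sys
from contextlib import contextmanager


@contextmanager
def isolated_top_level(*names):
    """Evict the named top-level packages and their submodules from
    sys.modules, both before and after the wrapped block.

    Hypothetical helper for illustration; not a torch.hub API.
    """
    prefixes = tuple(name + "." for name in names)

    def purge():
        for key in list(sys.modules):
            if key in names or key.startswith(prefixes):
                del sys.modules[key]

    purge()
    try:
        yield
    finally:
        purge()


# Intended usage (left commented out because it needs torch and network access):
#
# import torch
# with isolated_top_level("models", "utils"):
#     detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# with isolated_top_level("models", "utils"):
#     yolo = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
```

Purging submodules as well as the package itself avoids leaving stale entries such as models.backbone behind, which a bare pop of models would keep in the cache.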

@carloalbertobono
Author

Then thank you very much @NicolasHug and @vmoens for the prompt help and of course for the work!
All the best
cb

NicolasHug added a commit to pytorch/pytorch that referenced this issue Dec 15, 2021
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Dec 16, 2021
…69970)

Summary:
Pull Request resolved: #69970

This is a follow up to pytorch/hub#243

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33124060

Pulled By: NicolasHug

fbshipit-source-id: 298fe14b39a1aff3e0b029044c9a0db8bc82336a