Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto download pre-trained model from gdrive #352

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from utils.functions import MovingAverage, ProgressBar
from layers.box_utils import jaccard, center_size, mask_iou
from utils import timer
from utils.functions import SavePath
from utils.functions import SavePath, download_file_from_google_drive
from layers.output_utils import postprocess, undo_image_transformation
import pycocotools

Expand Down Expand Up @@ -1045,6 +1045,38 @@ def print_maps(all_maps):
print()


def check_model(model_path):
    """Ensure the weights file at ``model_path`` exists, downloading it if missing.

    If the file is not on disk, the filename is matched against a table of
    officially released pre-trained YOLACT models and the weights are fetched
    from Google Drive into ``model_path``.

    Args:
        model_path: Path (str or path-like) to the expected ``.pth`` weights file.

    Raises:
        SystemExit: with status 1 when the filename matches no known model.
    """
    model_path = str(model_path)
    print(model_path)

    # Google Drive file IDs for the officially released pre-trained weights.
    model_url_dict = {"yolact_resnet50_54_800000.pth": "1yp7ZbbDwvMiFJEq4ptVKTYTI2VeRDXl0",
                      "yolact_darknet53_54_800000.pth": "1dukLrTzZQEuhzitGkHaGjphlmRJOjVnP",
                      "yolact_base_54_800000.pth": "1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_",
                      "yolact_im700_54_800000.pth": "1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg",
                      "yolact_plus_resnet50_54_800000.pth": "1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP",
                      "yolact_plus_base_54_800000.pth": "15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB"
                      }

    if not os.path.isfile(model_path):
        print("Model not found, trying to download it...")

        # Create the destination folder if missing.  dirname() is '' for a
        # bare filename (current directory), in which case there is nothing
        # to create -- os.makedirs('') would raise FileNotFoundError.
        folder = os.path.dirname(model_path)
        if folder:
            os.makedirs(folder, exist_ok=True)

        # Look up the Google Drive file ID from the known models.
        file_id = ''
        for model_candidate in model_url_dict:
            if model_candidate in model_path:
                file_id = model_url_dict[model_candidate]
                break
        if file_id == '':
            print("No candidate for download found")
            # raise SystemExit instead of exit(): exit() is a site-module
            # convenience intended only for interactive sessions.
            raise SystemExit(1)
        download_file_from_google_drive(file_id, model_path)


if __name__ == '__main__':
parse_args()
Expand All @@ -1058,6 +1090,7 @@ def print_maps(all_maps):
args.trained_model = SavePath.get_latest('weights/', cfg.name)

if args.config is None:
check_model(args.trained_model)
model_path = SavePath.from_str(args.trained_model)
# TODO: Bad practice? Probably want to do a name lookup instead.
args.config = model_path.model_name + '_config'
Expand Down
28 changes: 27 additions & 1 deletion utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import deque
from pathlib import Path
from layers.interpolate import InterpolateModule
import requests

class MovingAverage():
""" Keeps an average window of the specified number of items. """
Expand Down Expand Up @@ -210,4 +211,29 @@ def make_layer(layer_cfg):
if not include_last_relu:
net = net[:-1]

return nn.Sequential(*(net)), in_channels
return nn.Sequential(*(net)), in_channels


def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to ``destination``.

    Based on https://stackoverflow.com/a/39225039/7036639.  For large files
    Google first serves a "cannot virus-scan this file" interstitial; the
    confirm token from that page's cookie is replayed in a second request to
    obtain the real content.

    Args:
        id: Google Drive file ID.  (The name shadows the ``id`` builtin but
            is kept for backward compatibility with existing callers.)
        destination: Local file path the downloaded bytes are written to.

    Raises:
        requests.HTTPError: if either request returns an HTTP error status.
    """
    def get_confirm_token(response):
        # The warning page sets a 'download_warning*' cookie whose value is
        # the confirm token needed to proceed with the real download.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, destination):
        CHUNK_SIZE = 32768
        with open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    URL = "https://docs.google.com/uc?export=download"
    # One session so the warning-page cookies carry over to the retry;
    # the with-block guarantees the connection pool is released.
    with requests.Session() as session:
        response = session.get(URL, params = { 'id' : id }, stream = True)
        # Fail fast on HTTP errors instead of silently writing the error
        # page into the weights file (which would yield a corrupt .pth).
        response.raise_for_status()
        token = get_confirm_token(response)
        if token:
            params = { 'id' : id, 'confirm' : token }
            response = session.get(URL, params = params, stream = True)
            response.raise_for_status()
        save_response_content(response, destination)