From a4820f81f8fcf09fdb4b47eaf167e559414e3587 Mon Sep 17 00:00:00 2001 From: Aymeric DUJARDIN Date: Fri, 21 Feb 2020 19:05:33 +0100 Subject: [PATCH 1/2] Adding auto download pre-trained model from gdrive --- eval.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/eval.py b/eval.py index 547bc0aae..073fc7dfb 100644 --- a/eval.py +++ b/eval.py @@ -28,6 +28,7 @@ import matplotlib.pyplot as plt import cv2 +import requests def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -1045,6 +1046,63 @@ def print_maps(all_maps): print() +def download_file_from_google_drive(id, destination): + #https://stackoverflow.com/a/39225039/7036639 + def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + def save_response_content(response, destination): + CHUNK_SIZE = 32768 + with open(destination, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + response = session.get(URL, params = { 'id' : id }, stream = True) + token = get_confirm_token(response) + if token: + params = { 'id' : id, 'confirm' : token } + response = session.get(URL, params = params, stream = True) + save_response_content(response, destination) + + +def check_model(model_path): + model_path = str(model_path) + print(model_path) + + model_url_dict = {"yolact_resnet50_54_800000.pth": "1yp7ZbbDwvMiFJEq4ptVKTYTI2VeRDXl0", + "yolact_darknet53_54_800000.pth": "1dukLrTzZQEuhzitGkHaGjphlmRJOjVnP", + "yolact_base_54_800000.pth": "1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_", + "yolact_im700_54_800000.pth": "1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg", + "yolact_plus_resnet50_54_800000.pth": "1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP", + "yolact_plus_base_54_800000.pth": "15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB" + } + + if not os.path.isfile(model_path): + print("Model not found, trying to download it...") + url = '' + + # Create folder if missing + folder=os.path.dirname(model_path) + if not os.path.exists(folder): + os.makedirs(folder) + + # Look for the model URL from the known models + for model_candidate in model_url_dict: + if model_candidate in model_path: + url = model_url_dict[model_candidate] + break + if url == '': + print("No candidate for download found") + exit(1) + output = model_path + download_file_from_google_drive(url, output) + if __name__ == '__main__': parse_args() @@ -1058,6 +1116,7 @@ def print_maps(all_maps): args.trained_model = SavePath.get_latest('weights/', cfg.name) if args.config is None: + check_model(args.trained_model) model_path = SavePath.from_str(args.trained_model) # TODO: Bad practice? Probably want to do a name lookup instead. args.config = model_path.model_name + '_config' From 88673a6e4b92d83159e0e21b2afcbba151ecf359 Mon Sep 17 00:00:00 2001 From: Aymeric Dujardin Date: Sun, 7 Jun 2020 14:50:01 +0200 Subject: [PATCH 2/2] Moving Gdrive download to utils --- eval.py | 28 +--------------------------- utils/functions.py | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/eval.py b/eval.py index 073fc7dfb..e6e0fafcc 100644 --- a/eval.py +++ b/eval.py @@ -4,7 +4,7 @@ from utils.functions import MovingAverage, ProgressBar from layers.box_utils import jaccard, center_size, mask_iou from utils import timer -from utils.functions import SavePath +from utils.functions import SavePath, download_file_from_google_drive from layers.output_utils import postprocess, undo_image_transformation import pycocotools @@ -28,7 +28,6 @@ import matplotlib.pyplot as plt import cv2 -import requests def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -1046,31 +1045,6 @@ def print_maps(all_maps): print() -def download_file_from_google_drive(id, destination): - #https://stackoverflow.com/a/39225039/7036639 - def get_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - return None - - def save_response_content(response, destination): - CHUNK_SIZE = 32768 - with open(destination, "wb") as f: - for chunk in response.iter_content(CHUNK_SIZE): - if chunk: # filter out keep-alive new chunks - f.write(chunk) - - URL = "https://docs.google.com/uc?export=download" - session = requests.Session() - response = session.get(URL, params = { 'id' : id }, stream = True) - token = get_confirm_token(response) - if token: - params = { 'id' : id, 'confirm' : token } - response = session.get(URL, params = params, stream = True) - save_response_content(response, destination) - - def check_model(model_path): model_path = str(model_path) print(model_path) diff --git a/utils/functions.py b/utils/functions.py index 3b7a4e45a..6c41b98e2 100644 --- a/utils/functions.py +++ b/utils/functions.py @@ -5,6 +5,7 @@ from collections import deque from pathlib import Path from layers.interpolate import InterpolateModule +import requests class MovingAverage(): """ Keeps an average window of the specified number of items. """ @@ -210,4 +211,29 @@ def make_layer(layer_cfg): if not include_last_relu: net = net[:-1] - return nn.Sequential(*(net)), in_channels \ No newline at end of file + return nn.Sequential(*(net)), in_channels + + +def download_file_from_google_drive(id, destination): + #https://stackoverflow.com/a/39225039/7036639 + def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + def save_response_content(response, destination): + CHUNK_SIZE = 32768 + with open(destination, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + response = session.get(URL, params = { 'id' : id }, stream = True) + token = get_confirm_token(response) + if token: + params = { 'id' : id, 'confirm' : token } + response = session.get(URL, params = params, stream = True) + save_response_content(response, destination) \ No newline at end of file