-
Notifications
You must be signed in to change notification settings - Fork 18
/
TEMPLATE.py
69 lines (48 loc) · 2.21 KB
/
TEMPLATE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from pathlib import Path
import gdown, py3_wget
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, add_to_path
add_to_path(THIRD_PARTY_DIR.joinpath('path/to/submodule'))
from submodule import model, other_components
class NewMatcher(BaseMatcher):
weights_src = "url-to-weights-hosted-online"
model_path = WEIGHTS_DIR.joinpath("my_weights.ckpt")
divisible_size = 32
def __init__(self, device="cpu", *args, **kwargs):
super().__init__(device, **kwargs)
self.download_weights()
self.matcher = model()
self.matcher.load_state_dict(torch.load(self.model_path, map_location=torch.device("cpu"))["state_dict"])
def download_weights(self):
# check if weights exist, otherwise download them
if not Path(NewMatcher.model_path).is_file():
print("Downloading model... (takes a while)")
# if a google drive link
gdown.download(
NewMatcher.weights_src,
output=str(NewMatcher.model_path),
fuzzy=True,
)
# else
py3_wget.download_file(NewMatcher.model_path, NewMatcher.model_path)
def preprocess(self, img):
_, h, w = img.shape
orig_shape = h, w
# if requires divisibility
img = resize_to_divisible(img, self.divisible_size)
# if model requires a "batch"
img = img.unsqueeze(0)
return img, orig_shape
def _forward(self, img0, img1):
img0, img0_orig_shape = self.preprocess(img0)
img1, img1_orig_shape = self.preprocess(img1)
batch = {"image0": img0, "image1": img1}
# run through model to get outputs
output = self.matcher(batch)
# postprocess model output to get kpts, desc, etc
# if we had to resize the img to divisible, then rescale the kpts back to input img size
H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
return mkpts0, mkpts1, keypoints_0, keypoints_1, desc0, desc1