# demo_match.py (48 lines, 38 loc, 1.69 KB)
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from dkm.utils.utils import tensor_to_pil
from dkm import DKMv3_outdoor
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_rgb(path, size, device):
    """Load an image file as a normalized RGB tensor.

    Args:
        path: Image file path.
        size: Target ``(width, height)``, in the order ``PIL.Image.resize``
            expects.
        device: Torch device to place the result on.

    Returns:
        Float tensor of shape ``(3, height, width)`` with values in ``[0, 1]``.
    """
    # convert("RGB") guarantees exactly three channels, so grayscale, palette
    # or RGBA inputs do not break the channel permute below. The context
    # manager closes the file handle that Image.open keeps open.
    with Image.open(path) as im:
        rgb = im.convert("RGB").resize(size)
    return (torch.tensor(np.array(rgb)) / 255).to(device).permute(2, 0, 1)


def visualize_warp(x1, x2, warp, certainty):
    """Render a side-by-side visualization of a dense two-way warp.

    The left half of the result is ``x2`` resampled at the coordinates stored
    in ``warp[:, :W, 2:]``. The right half is ``x1`` resampled at the
    coordinates stored in ``warp[:, W:, :2]``. Every pixel is then blended
    towards white with weight ``1 - certainty``.

    Args:
        x1: First image as a ``(3, H, W)`` tensor.
        x2: Second image as a ``(3, H, W)`` tensor.
        warp: ``(H, 2W, 4)`` tensor of ``grid_sample``-normalized ``[-1, 1]``
            coordinates. NOTE(review): this layout is the one this demo
            consumes; it is assumed to match DKM's ``match`` output, so
            confirm it against the model.
        certainty: Blend weights broadcastable to ``(H, 2W)``.

    Returns:
        ``(3, H, 2W)`` tensor.
    """
    half = warp.shape[1] // 2
    left_grid = warp[:, :half, 2:][None]
    right_grid = warp[:, half:, :2][None]
    im2_transfer_rgb = F.grid_sample(
        x2[None], left_grid, mode="bilinear", align_corners=False
    )[0]
    im1_transfer_rgb = F.grid_sample(
        x1[None], right_grid, mode="bilinear", align_corners=False
    )[0]
    warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2)
    white_im = torch.ones_like(warp_im)
    return certainty * warp_im + (1 - certainty) * white_im


if __name__ == "__main__":
    from argparse import ArgumentParser
    from pathlib import Path

    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
    parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
    # parse_known_args ignores unrecognized flags instead of exiting with an error.
    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    # Create model
    dkm_model = DKMv3_outdoor(device=device)

    # Match: match() takes the image file paths directly.
    warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
    # Sparse correspondences are not needed for this visualization. If you
    # want them, draw them with dkm_model.sample(warp, certainty).

    # The warp holds both images side by side (width 2W), so take the
    # resolution from it rather than hard-coding the model's output size.
    H, W = warp.shape[0], warp.shape[1] // 2
    x1 = load_rgb(im1_path, (W, H), device)
    x2 = load_rgb(im2_path, (W, H), device)

    vis_im = visualize_warp(x1, x2, warp, certainty)
    # The default save path lives in demo/, which may not exist yet.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    tensor_to_pil(vis_im, unnormalize=False).save(save_path)