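"""Run trained multi-warp light-field depth and pose networks over one sequence.

Loads the LFDispNet and LFPoseNet weights saved by a training run, performs
inference on every frame of the requested sequence, and writes the target
views, depth maps, warped reference views, photometric difference images and
estimated poses into <save_path>/results/<seq>.

Example invocation (paths are illustrative, not part of the repository):

    python infer_multiwarp.py --config runs/epi/config.pkl --seq seq_000
"""
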
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

import custom_transforms
from lfmodels import LFDispNet as DispNetS
from lfmodels import LFPoseNet as PoseNet
from lfmodels import EpiEncoder, RelativeEpiEncoder
from loss_functions import multiwarp_photometric_loss
from multiwarp_dataloader import getValidationFocalstackLoader, getValidationStackedLFLoader, getValidationEpiLoader
from utils import load_config
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, type=str, help="Pkl file containing the training configuration")
parser.add_argument("--seq", required=True, type=str, help="Name of the sequence to run inference on")
parser.add_argument("--use-latest-not-best", action="store_true",
                    help="Use the latest set of weights rather than the best")
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def main():
    poses = []
    args = parser.parse_args()
    config = load_config(args.config)
    output_dir = os.path.join(config.save_path, "results", args.seq)

    if args.use_latest_not_best:
        config.posenet = os.path.join(config.save_path, "posenet_checkpoint.pth.tar")
        config.dispnet = os.path.join(config.save_path, "dispnet_checkpoint.pth.tar")
        output_dir = output_dir + "-latest"
    else:
        config.posenet = os.path.join(config.save_path, "posenet_best.pth.tar")
        config.dispnet = os.path.join(config.save_path, "dispnet_best.pth.tar")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "depth"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "warps"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "diffs"), exist_ok=True)
    transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])
    disp_encoder = None
    pose_encoder = None

    if config.lfformat == 'focalstack':
        dataset = getValidationFocalstackLoader(config, args.seq, transform, shuffle=False)
        print("Loading images as focalstack")
    elif config.lfformat == 'stack':
        dataset = getValidationStackedLFLoader(config, args.seq, transform, shuffle=False)
        print("Loading images as stack")
    elif config.lfformat == 'epi':
        dataset = getValidationEpiLoader(config, args.seq, transform, shuffle=False)
        print("Loading images as tiled EPIs")
    else:
        raise ValueError("Unknown light field format '{}'".format(config.lfformat))
    dispnet_input_channels = posenet_input_channels = dataset[0]['tgt_lf_formatted'].shape[0]
    output_channels = len(config.cameras)

    if config.lfformat == 'epi':
        disp_encoder = EpiEncoder('vertical', config.tilesize).to(device)
        pose_encoder = RelativeEpiEncoder('vertical', config.tilesize).to(device)
        # With EPI encoding the networks see 16 encoded channels plus one per camera
        posenet_input_channels = 16 + len(config.cameras)
        dispnet_input_channels = 16 + len(config.cameras)
        print("Initialised disp and pose encoders")

    print(f"[DispNet] Using {dispnet_input_channels} input channels, {output_channels} output channels")
    print(f"[PoseNet] Using {posenet_input_channels} input channels")
    disp_net = DispNetS(in_channels=dispnet_input_channels, out_channels=output_channels,
                        encoder=disp_encoder).to(device)
    weights = torch.load(config.dispnet, map_location=device)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    pose_net = PoseNet(in_channels=posenet_input_channels, nb_ref_imgs=2, encoder=pose_encoder).to(device)
    weights = torch.load(config.posenet, map_location=device)
    pose_net.load_state_dict(weights['state_dict'])
    pose_net.eval()
    print("Loaded dispnet and posenet")
    for i, validData in enumerate(dataset):
        print("{:03d}/{:03d}".format(i + 1, len(dataset)), end="\r")

        # Move the target view, reference views and camera intrinsics onto the device
        tgt_formatted = validData['tgt_lf_formatted'].unsqueeze(0).to(device)
        tgt = validData['tgt_lf'].unsqueeze(0).to(device)
        ref_formatted = [r.unsqueeze(0).to(device) for r in validData['ref_lfs_formatted']]
        ref = [r.unsqueeze(0).to(device) for r in validData['ref_lfs']]
        intrinsics = torch.Tensor(validData['intrinsics']).unsqueeze(0).to(device)
        metadata = validData['metadata']
        metadata['cameras'] = torch.tensor(metadata['cameras']).unsqueeze(1).to(device)

        if disp_net.hasEncoder():
            tgt_encoded_d = disp_net.encode(tgt_formatted, tgt)
        else:
            tgt_encoded_d = tgt_formatted

        if pose_net.hasEncoder():
            tgt_encoded_p, ref_encoded_p = pose_net.encode(tgt_formatted, tgt, ref_formatted, ref)
        else:
            tgt_encoded_p = tgt_formatted
            ref_encoded_p = ref_formatted

        output = disp_net(tgt_encoded_d)
        pose = pose_net(tgt_encoded_p, ref_encoded_p)

        pe, warped, diff = multiwarp_photometric_loss(
            tgt, ref, intrinsics, output, pose, metadata, config.rotation_mode, config.padding_mode
        )
        # Save the target view, predicted depth, first warped reference view and its photometric difference
        outfile = os.path.join(output_dir, "{:06d}.png".format(i))
        plt.imsave(outfile, tgt.cpu().numpy()[0, 0, :, :], cmap='gray')
        outfile = os.path.join(output_dir, "depth/{:06d}.png".format(i))
        plt.imsave(outfile, output.cpu().numpy()[0, 0, :, :])
        outfile = os.path.join(output_dir, "warps/{:06d}.png".format(i))
        plt.imsave(outfile, warped[0][0].cpu().numpy()[0, 0, :, :], cmap='gray')
        outfile = os.path.join(output_dir, "diffs/{:06d}.png".format(i))
        plt.imsave(outfile, diff[0][0].cpu().numpy()[0, 0, :, :])
        poses.append(pose[0, 0, :].cpu().numpy())

    np.save(os.path.join(output_dir, "poses.npy"), np.array(poses))
    print("\nok")

if __name__ == '__main__':
    main()