infer_supervised.py
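"""Run pose inference on a single sequence with a trained LFPoseNet.

Loads the training configuration from a pickle file, restores the network
weights from the chosen checkpoint, runs the pose network over every sample
in the requested sequence, and saves the predicted poses to
<save_path>/results/<seq>/poses.npy.
"""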
import argparse
import os

import numpy as np
import torch
from tqdm import tqdm

import custom_transforms
from dataloader import SequenceFolder
from lfmodels import LFPoseNet as PoseNet
from utils import load_config
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, type=str, help="Pickle file containing the training configuration")
parser.add_argument("--seq", required=True, type=str, help="Name of the sequence to perform inference on")
parser.add_argument("--use-latest-not-best", action="store_true", help="Use the latest set of weights rather than the best")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@torch.no_grad()
def main():
    args = parser.parse_args()
    config = load_config(args.config)
    output_dir = os.path.join(config.save_path, "results", args.seq)

    # Select which checkpoint to load: the most recent one or the best-scoring one.
    if args.use_latest_not_best:
        config.posenet = os.path.join(config.save_path, "posenet_checkpoint.pth.tar")
        output_dir = output_dir + "-latest"
    else:
        config.posenet = os.path.join(config.save_path, "posenet_best.pth.tar")
    # exist_ok so that re-running inference on the same sequence does not crash.
    os.makedirs(output_dir, exist_ok=True)
    # Convert arrays to tensors and normalise with per-channel mean 0.5, std 0.5.
    transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = SequenceFolder(
        config.data,
        cameras=config.cameras,
        gray=config.gray,
        sequence_length=config.sequence_length,
        shuffle=False,
        train=False,
        transform=transform,
        sequence=args.seq,
    )
    # Infer the network's input channel count from the first sample's light-field stack.
    input_channels = dataset[0][1].shape[0]
    pose_net = PoseNet(in_channels=input_channels, nb_ref_imgs=2, output_exp=False).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    weights = torch.load(config.posenet, map_location=device)
    pose_net.load_state_dict(weights['state_dict'])
    pose_net.eval()
    # Run inference sample by sample; tqdm replaces the manual carriage-return counter.
    poses = []
    for tgt, tgt_lf, ref, ref_lf, k, kinv, pose_gt in tqdm(dataset):
        tgt = tgt.unsqueeze(0).to(device)
        ref = [r.unsqueeze(0).to(device) for r in ref]
        tgt_lf = tgt_lf.unsqueeze(0).to(device)
        ref_lf = [r.unsqueeze(0).to(device) for r in ref_lf]
        exp, pose = pose_net(tgt_lf, ref_lf)
        # Keep the predicted pose for the second reference frame of this sample.
        poses.append(pose[0, 1, :].cpu().numpy())

    output_file = os.path.join(output_dir, "poses.npy")
    np.save(output_file, poses)
    print("ok")
if __name__ == '__main__':
    main()
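# Example invocation (a sketch; the config path and sequence name below are
# illustrative assumptions, not files shipped with this script):
#
#   python infer_supervised.py --config checkpoints/run1/config.pkl --seq seq01
#
# This writes the predicted poses to checkpoints/run1/results/seq01/poses.npy,
# which can be reloaded with np.load() for evaluation or plotting.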