Skip to content

Commit f5c9277

Browse files
authored
Merge pull request #101 from HuangJunJie2017/master
Add a new demo for estimating optical flow on a single image pair
2 parents f1d475a + 691a763 commit f5c9277

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

run_a_pair.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
import numpy as np
3+
import argparse
4+
5+
from Networks.FlowNet2 import FlowNet2 # the path is depended on where you create this module
6+
from frame_utils import read_gen # the path is depended on where you create this module
7+
8+
if __name__ == '__main__':
9+
# obtain the necessary args for construct the flownet framework
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
12+
parser.add_argument("--rgb_max", type=float, default=255.)
13+
args = parser.parse_args()
14+
15+
# initial a Net
16+
net = FlowNet2(args).cuda()
17+
# load the state_dict
18+
dict = torch.load("/home/hjj/PycharmProjects/flownet2_pytorch/FlowNet2_checkpoint.pth.tar")
19+
net.load_state_dict(dict["state_dict"])
20+
21+
# load the image pair, you can find this operation in dataset.py
22+
pim1 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img0.ppm")
23+
pim2 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img1.ppm")
24+
images = [pim1, pim2]
25+
images = np.array(images).transpose(3, 0, 1, 2)
26+
im = torch.from_numpy(images.astype(np.float32)).unsqueeze(0).cuda()
27+
28+
# process the image pair to obtian the flow
29+
result = net(im).squeeze()
30+
31+
32+
# save flow, I reference the code in scripts/run-flownet.py in flownet2-caffe project
33+
def writeFlow(name, flow):
34+
f = open(name, 'wb')
35+
f.write('PIEH'.encode('utf-8'))
36+
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
37+
flow = flow.astype(np.float32)
38+
flow.tofile(f)
39+
f.flush()
40+
f.close()
41+
42+
43+
data = result.data.cpu().numpy().transpose(1, 2, 0)
44+
writeFlow("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img.flo", data)

0 commit comments

Comments
 (0)