1+ import torch
2+ import numpy as np
3+ import argparse
4+
5+ from Networks .FlowNet2 import FlowNet2 # the path is depended on where you create this module
6+ from frame_utils import read_gen # the path is depended on where you create this module
7+
8+ if __name__ == '__main__' :
9+ # obtain the necessary args for construct the flownet framework
10+ parser = argparse .ArgumentParser ()
11+ parser .add_argument ('--fp16' , action = 'store_true' , help = 'Run model in pseudo-fp16 mode (fp16 storage fp32 math).' )
12+ parser .add_argument ("--rgb_max" , type = float , default = 255. )
13+ args = parser .parse_args ()
14+
15+ # initial a Net
16+ net = FlowNet2 (args ).cuda ()
17+ # load the state_dict
18+ dict = torch .load ("/home/hjj/PycharmProjects/flownet2_pytorch/FlowNet2_checkpoint.pth.tar" )
19+ net .load_state_dict (dict ["state_dict" ])
20+
21+ # load the image pair, you can find this operation in dataset.py
22+ pim1 = read_gen ("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img0.ppm" )
23+ pim2 = read_gen ("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img1.ppm" )
24+ images = [pim1 , pim2 ]
25+ images = np .array (images ).transpose (3 , 0 , 1 , 2 )
26+ im = torch .from_numpy (images .astype (np .float32 )).unsqueeze (0 ).cuda ()
27+
28+ # process the image pair to obtian the flow
29+ result = net (im ).squeeze ()
30+
31+
32+ # save flow, I reference the code in scripts/run-flownet.py in flownet2-caffe project
33+ def writeFlow (name , flow ):
34+ f = open (name , 'wb' )
35+ f .write ('PIEH' .encode ('utf-8' ))
36+ np .array ([flow .shape [1 ], flow .shape [0 ]], dtype = np .int32 ).tofile (f )
37+ flow = flow .astype (np .float32 )
38+ flow .tofile (f )
39+ f .flush ()
40+ f .close ()
41+
42+
43+ data = result .data .cpu ().numpy ().transpose (1 , 2 , 0 )
44+ writeFlow ("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img.flo" , data )
0 commit comments