-
Notifications
You must be signed in to change notification settings - Fork 7
/
inference.py
105 lines (78 loc) · 3.11 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Run DeepLab-ResNet on a given image.
This script computes a segmentation mask for a given image.
"""
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time
from PIL import Image
import tensorflow as tf
import numpy as np
from deeplab_resnet import DeepLabResNetModel, ImageReader, decode_labels, dense_crf, prepare_label
# Default value of the --save-dir command-line option.
SAVE_DIR = './output/'

# Per-channel values subtracted from the image in main() after it has been
# reordered from RGB to BGR.
IMG_MEAN = np.array(
    [104.00698793, 116.66876762, 122.67891434],
    dtype=np.float32,
)
def get_arguments(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
      argv: optional list of argument strings to parse. When None (the
        default) argparse reads the process command line, exactly as before.

    Returns:
      An argparse.Namespace with the attributes `img_path`,
      `model_weights` and `save_dir`.
    """
    parser = argparse.ArgumentParser(description="DeepLab-ResNet Network Inference.")
    parser.add_argument("img_path", type=str,
                        help="Path to the RGB image file.")
    parser.add_argument("model_weights", type=str,
                        help="Path to the file with model weights.")
    parser.add_argument("--save-dir", type=str, default=SAVE_DIR,
                        help="Where to save predicted mask.")
    return parser.parse_args(argv)
def load(saver, sess, ckpt_path):
    """Restore model parameters from a checkpoint and report it.

    Args:
      saver: object whose `restore(sess, ckpt_path)` method is called
        (a TensorFlow saver in this script).
      sess: TensorFlow session the parameters are restored into.
      ckpt_path: path to the checkpoint file with the parameters.
    """
    saver.restore(sess, ckpt_path)
    message = "Restored model parameters from {}".format(ckpt_path)
    print(message)
def main():
    """Build the model, run inference on one image and save the mask.

    Reads the JPEG at `args.img_path`, builds the DeepLabResNetModel graph,
    restores the weights from `args.model_weights`, passes the softmax
    scores through `dense_crf`, and writes the result of `decode_labels`
    to `mask.png` inside `args.save_dir` (created if it does not exist).
    """
    args = get_arguments()

    # Prepare image.
    img_orig = tf.image.decode_jpeg(tf.read_file(args.img_path), channels=3)
    # Convert RGB to BGR.
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img_orig)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Subtract the mean.
    img -= IMG_MEAN

    # Create network; the model is fed a batch of one image.
    net = DeepLabResNetModel({'data': tf.expand_dims(img, axis=0)}, is_training=False)

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions, resized back to the height/width of the input image.
    raw_output = net.layers['fc1_voc12']
    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(img)[0:2])

    # CRF: dense_crf is a Python function, hence the py_func wrapper. It
    # receives the softmax scores together with the original (RGB) image.
    raw_output_up = tf.nn.softmax(raw_output_up)
    raw_output_up = tf.py_func(
        dense_crf,
        [raw_output_up, tf.expand_dims(img_orig, axis=0)],
        tf.float32)
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = tf.expand_dims(raw_output_up, axis=3)

    # Set up TF session; the context manager releases it once the
    # prediction has been fetched.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # Load weights.
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.model_weights)

        # Perform inference.
        preds = sess.run(pred)

    msk = decode_labels(preds)
    im = Image.fromarray(msk[0])

    # An empty save dir means "current directory"; nothing to create then.
    if args.save_dir and not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    # os.path.join works whether or not --save-dir ends with a separator.
    out_path = os.path.join(args.save_dir, 'mask.png')
    im.save(out_path)
    print('The output file has been saved to {}'.format(out_path))
if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()