-
Notifications
You must be signed in to change notification settings - Fork 17
/
predict.py
89 lines (71 loc) · 2.14 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 26 21:10:16 2017
@author: dhaval
"""
import sys
import argparse
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model
from keras.applications.inception_v3 import preprocess_input
# Size every image is resized to before being handed to the model by predict().
target_size = (299, 299) #fixed size for InceptionV3 architecture
def predict(model, img, target_size):
    """Run model prediction on a single image.

    Args:
        model: keras model (any object exposing ``predict``)
        img: PIL format image
        target_size: (w,h) pair the image is resized to before inference

    Returns:
        The first entry of the array returned by ``model.predict``, i.e. the
        prediction for the one image that makes up the batch.
    """
    # Grayscale, palette and RGBA images would otherwise yield an array with
    # 1 or 4 channels, which does not match the 3-channel input the
    # InceptionV3 preprocessing/model pipeline is fed here.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # PIL reports size as a tuple; normalise so a list argument compares equal.
    size = tuple(target_size)
    if img.size != size:
        img = img.resize(size)

    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch dimension
    x = preprocess_input(x)
    preds = model.predict(x)
    return preds[0]
def plot_preds(image, preds):
    """Save a horizontal bar chart of the predicted probabilities to 'out.png'.

    Args:
        image: PIL image (not used by the plot; kept so existing callers
            keep working)
        preds: two probabilities, drawn against the labels "cat" and "dog"
            in that order
    """
    labels = ("cat", "dog")
    fig = plt.figure()
    try:
        plt.barh([0, 1], preds, alpha=0.5)
        plt.yticks([0, 1], labels)
        plt.xlabel('Probability')
        plt.xlim(0, 1.01)
        plt.tight_layout()
        plt.savefig('out.png')
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so
        # release it here to avoid accumulating figures across calls.
        plt.close(fig)
if __name__=="__main__":
    # Command-line entry point: classify a local image and/or an image
    # fetched from a URL with a saved keras model, then plot the result.
    a = argparse.ArgumentParser()
    a.add_argument("--image", help="path to image")
    a.add_argument("--image_url", help="url to image")
    a.add_argument("--model")
    args = a.parse_args()

    if args.image is None and args.image_url is None:
        a.print_help()
        sys.exit(1)

    # Fail with a clear usage message instead of letting load_model(None)
    # raise an obscure error.
    if args.model is None:
        a.error("--model is required")

    model = load_model(args.model)

    # NOTE: when both --image and --image_url are given, both are processed
    # and the second plot overwrites the first, since plot_preds always
    # writes to the same file.
    if args.image is not None:
        img = Image.open(args.image)
        preds = predict(model, img, target_size)
        plot_preds(img, preds)

    if args.image_url is not None:
        # A timeout keeps the script from hanging forever on a dead host.
        response = requests.get(args.image_url, timeout=30)
        # Surface HTTP errors directly rather than handing an error page
        # to Image.open.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        preds = predict(model, img, target_size)
        plot_preds(img, preds)