-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlive_simpsons.py
96 lines (61 loc) · 2.2 KB
/
live_simpsons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import cv2
import tensorflow as tf
import os
import sys
import matplotlib.pyplot as plt
import time
import pickle
# Directory holding the TensorFlow checkpoint ("model.ckpt" / "model.ckpt.meta")
# that load_trained_model() restores.
# NOTE(review): the constant is named MNIST_MODEL but the path points at a
# "simpsons" model directory -- the name looks like a leftover; confirm before renaming.
MNIST_MODEL = "./trained_models/simpsons/"
# Pickled object read by load_characters(); classify_character() calls .get()
# on it with the predicted class index and expects a label string back.
CHARACTER_MAP = "./dumps/character_map.dump"
# Height and width (in pixels) every frame is resized to before being fed to the network.
H, W = 128, 128
def load_characters(path=None):
    """Load and return the pickled character map.

    Args:
        path: Optional path of the pickle file to read. When omitted (the
            original call style), the module-level ``CHARACTER_MAP`` path is
            used.

    Returns:
        Whatever object was pickled into the file. The rest of this file
        calls ``.get(class_index)`` on it, so a dict-like mapping is expected.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the file is not a valid pickle.
    """
    if path is None:
        path = CHARACTER_MAP
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load dumps produced by a trusted training run.
    # The previous code passed pickle.HIGHEST_PROTOCOL as open()'s third
    # positional argument, which is the *buffering* size, not a pickle option,
    # and never closed the handle; the context manager fixes both.
    with open(path, "rb") as dump_file:
        return pickle.load(dump_file)
def load_trained_model():
    """Restore the saved TensorFlow graph and weights from ``MNIST_MODEL``.

    Returns:
        A ``(session, input_tensor, output_tensor)`` tuple: the session the
        checkpoint was restored into, the graph tensor named ``"x:0"`` and
        the graph tensor named ``"out/BiasAdd:0"``.
    """
    session = tf.Session()

    # Rebuild the graph structure from the .meta file, then load the weights
    # from the checkpoint into the session.
    meta_path = os.path.join(MNIST_MODEL, "model.ckpt.meta")
    checkpoint_path = os.path.join(MNIST_MODEL, "model.ckpt")
    restorer = tf.train.import_meta_graph(meta_path)
    restorer.restore(session, checkpoint_path)

    # Look up the input and output tensors by their names in the restored graph.
    restored_graph = tf.get_default_graph()
    input_tensor = restored_graph.get_tensor_by_name("x:0")
    output_tensor = restored_graph.get_tensor_by_name("out/BiasAdd:0")
    return session, input_tensor, output_tensor
def softmax(_in):
    """Return the softmax of ``_in`` taken over all of its elements.

    Args:
        _in: Array-like of scores (logits).

    Returns:
        An array of the same shape whose entries are non-negative and sum
        to 1. An empty input yields an empty array.
    """
    scores = np.asarray(_in)
    if scores.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(scores)
    # Subtracting the maximum leaves the result mathematically unchanged but
    # keeps exp() from overflowing to inf (which produced NaN via inf/inf).
    exponentials = np.exp(scores - np.max(scores))
    return exponentials / np.sum(exponentials)
def classify_character(img, character_map, sess, x, out):
    """Classify one frame and draw the predicted label on it when confident.

    Args:
        img: Frame as read by OpenCV; it is converted with
            ``cv2.COLOR_BGR2RGB`` before being fed to the network.
            The label is drawn onto this array in place.
        character_map: Mapping queried with ``.get(class_index)`` that
            returns the label string for a class index.
        sess: TensorFlow session used to evaluate ``out``.
        x: Input placeholder fed with a ``(1, H, W, 3)`` array.
        out: Output tensor whose values are passed through ``softmax``.

    Returns:
        ``img`` itself, annotated when the top probability exceeds 0.8 and
        the class index has a label in ``character_map``.
    """
    # cvtColor returns a new array, so the caller's frame is not modified here.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize as (width, height); the reshape below is (H, W).
    rgb = cv2.resize(rgb, (W, H))

    logits = sess.run(out, feed_dict={x: np.reshape(rgb, (1, H, W, 3))})
    probabilities = softmax(np.squeeze(logits))

    best_index = int(np.argmax(probabilities))
    p = probabilities[best_index]
    character = character_map.get(best_index)

    # Skip drawing when the index has no label: .get() returns None and
    # None.replace() would crash the video loop.
    if p > 0.8 and character is not None:
        cv2.putText(img, character.replace("_", " "), (10, 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
    return img
def run(vid, character_map, sess, x, out):
    """Play a video source in a fullscreen window with live classification.

    Each frame is passed through ``classify_character`` and displayed.
    Playback stops when the source yields no more frames or the user
    presses ``q``. The loop is throttled to at most 24 iterations per second.

    Args:
        vid: Source handed to ``cv2.VideoCapture``.
        character_map: Class-index -> label mapping for ``classify_character``.
        sess: TensorFlow session.
        x: Network input placeholder.
        out: Network output tensor.
    """
    FPS = 24
    # Float division so the period is not truncated to 0 under Python 2.
    frame_period = 1.0 / FPS

    cap = cv2.VideoCapture(vid)
    try:
        cv2.namedWindow("Frame", cv2.WND_PROP_FULLSCREEN)
        cv2.setWindowProperty("Frame", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

        now = time.time()
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream or a source that failed to open: read()
                # yields no frame, so stop instead of crashing on it.
                break

            img = classify_character(frame, character_map, sess, x, out)
            cv2.imshow("Frame", img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            # Sleep off whatever is left of this frame's time budget.
            interval = time.time() - now
            if interval < frame_period:
                time.sleep(frame_period - interval)
            now = time.time()
    finally:
        # Always free the capture device and close windows, even on error.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # The single required argument is the video source passed to run().
    if len(sys.argv) < 2:
        # Exit with a usage hint rather than an IndexError traceback.
        sys.exit("usage: {} <video>".format(sys.argv[0]))
    vid = sys.argv[1]
    sess, x, out = load_trained_model()
    character_map = load_characters()
    run(vid, character_map, sess, x, out)