-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinference.py
69 lines (49 loc) · 1.77 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
def read_image(image):
    """Decode an image file into a pixel array.

    Args:
        image: Path or file-like object pointing at the image to load.

    Returns:
        The array produced by ``matplotlib.image.imread`` for that file.
    """
    pixels = mpimg.imread(image)
    return pixels
def format_image(image, size=(224, 224)):
    """Turn a decoded image array into a one-image batch for the model.

    Args:
        image: Pixel array such as the one returned by ``read_image``
            (height x width x channels). Integer arrays are scaled by the
            maximum value of their dtype (255 for uint8). Floating-point
            arrays are left unscaled; this is what ``matplotlib.image.imread``
            returns for PNG files, already in [0, 1].
        size: Target ``(height, width)`` for resizing. Defaults to 224x224.

    Returns:
        A float32 tensor of shape ``(1, size[0], size[1], channels)`` with
        values in [0, 1].
    """
    pixels = np.asarray(image)
    # RGBA images (common for PNG uploads) carry an alpha channel the model
    # does not take; keep only the colour channels.
    if pixels.ndim == 3 and pixels.shape[-1] == 4:
        pixels = pixels[..., :3]
    # Add the leading batch axis expected by tf.image.resize and the model.
    batch = tf.image.resize(pixels[tf.newaxis, ...], list(size))
    # tf.image.resize returns float32 regardless of the input dtype, so the
    # scaling decision must use the ORIGINAL dtype. Dividing float PNG data
    # (already 0-1) by 255 again would shrink every pixel toward zero.
    if np.issubdtype(pixels.dtype, np.integer):
        batch = batch / float(np.iinfo(pixels.dtype).max)
    return batch
def get_category(img):
    """Predict the class name of an image with the bundled TFLite model.

    Args:
        img: Image file (path or file-like object) readable by ``read_image``.

    Returns:
        str: The predicted class name: ``'rock'``, ``'paper'`` or
        ``'scissors'``.
    """
    class_names = ('rock', 'paper', 'scissors')

    # Resolve the model relative to this module, as plot_category does for
    # its output directory, so prediction does not depend on the process's
    # current working directory.
    model_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'static', 'model', 'converted_model.tflite',
    )

    # Interpreter interface for TensorFlow Lite models. A fresh instance per
    # call keeps the tensor state written by set_tensor private to this call.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Only the first input and first output of the model are used.
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    model_input = format_image(read_image(img))
    interpreter.set_tensor(input_index, model_input)
    interpreter.invoke()

    predictions_array = interpreter.get_tensor(output_index)
    predicted_label = int(np.argmax(predictions_array))
    return class_names[predicted_label]
def plot_category(img, current_time):
    """Save a PNG copy of the input image under ``static/images/``.

    Args:
        img: Image file (path or file-like object) readable by
            ``matplotlib.image.imread``.
        current_time: Value interpolated into the output file name
            ``output_{current_time}.png``. It is presumably a timestamp
            supplied by the caller so each request gets its own file;
            verify against callers.

    Returns:
        str: Absolute path of the written PNG file. Existing callers that
        ignore the return value are unaffected.
    """
    read_img = mpimg.imread(img)
    root_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(root_dir, 'static', 'images')
    # Create the output directory on first use instead of failing in imsave.
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, f'output_{current_time}.png')
    # plt.imsave opens the target for writing and overwrites any existing
    # file, so a separate exists()/remove() step is unnecessary (and racy).
    plt.imsave(file_path, read_img)
    return file_path