-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathindex.py
48 lines (35 loc) · 1.69 KB
/
index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib
import numpy as np
def export_model_for_mobile(model_name, input_node_name, output_node_name,
                            output_dir='out'):
    """Freeze the current Keras session graph and write an optimized copy.

    The graph and variables are taken from the Keras backend session
    (``K.get_session()``), so the model must already be built (and normally
    trained) in that session before this is called.

    Files written under ``output_dir``:

    * ``<model_name>_graph.pbtxt`` -- graph definition of the session.
    * ``<model_name>.chkp`` -- checkpoint saved from the same session.
    * ``frozen_<model_name>.pb`` -- result of ``freeze_graph``.
    * ``tensorflow_lite_<model_name>.pb`` -- the frozen graph after
      ``optimize_for_inference``.

    Args:
        model_name: Base name used for every generated file.
        input_node_name: Name of the graph node that receives the input.
        output_node_name: Name of the graph node that produces the output.
        output_dir: Directory that receives the generated files.
            Defaults to ``'out'``.
    """
    import os  # local import: only this function needs it

    graph_file = model_name + '_graph.pbtxt'
    graph_path = os.path.join(output_dir, graph_file)
    checkpoint_path = os.path.join(output_dir, model_name + '.chkp')
    frozen_path = os.path.join(output_dir, 'frozen_' + model_name + '.pb')
    optimized_path = os.path.join(
        output_dir, 'tensorflow_lite_' + model_name + '.pb')

    session = K.get_session()

    # Dump the graph structure and the variable values separately;
    # freeze_graph merges the two into a single constant-only GraphDef.
    tf.train.write_graph(session.graph_def, output_dir, graph_file)
    tf.train.Saver().save(session, checkpoint_path)

    freeze_graph.freeze_graph(
        input_graph=graph_path,
        input_saver=None,
        input_binary=False,  # the graph file above is the text .pbtxt
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_name,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_path,
        clear_devices=True,
        initializer_nodes="",
    )

    # Read the frozen graph back so it can be optimized for inference.
    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        [input_node_name],
        [output_node_name],
        tf.float32.as_datatype_enum,
    )

    # GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
def main():
    """Train a small XOR network and export it for mobile use.

    The network is a 2-input model with one 8-unit tanh layer followed by a
    single sigmoid unit. It is fitted on the four XOR rows with SGD, exported
    through ``export_model_for_mobile`` under the name ``xor_nn``, and then
    its summary is printed.
    """
    # The four rows of the XOR truth table and their expected outputs.
    xor_inputs = np.array([
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1],
    ])
    xor_targets = np.array([
        [0],
        [1],
        [1],
        [0],
    ])

    network = Sequential([
        Dense(8, input_dim=2, activation='tanh'),
        Dense(1, activation='sigmoid'),
    ])

    sgd = SGD(lr=0.1)
    network.compile(loss='binary_crossentropy', optimizer=sgd)
    network.fit(xor_inputs, xor_targets, batch_size=1, nb_epoch=1000)

    # The node names below are the ones the export step looks up in the graph.
    export_model_for_mobile('xor_nn', "dense_1_input", "dense_2/Sigmoid")

    network.summary()
# Run the training/export pipeline only when executed as a script.
if __name__ == "__main__":
    main()