-
Notifications
You must be signed in to change notification settings - Fork 0
/
read_graph.py
192 lines (178 loc) · 6.92 KB
/
read_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import sys
from collections import deque

import tensorflow as tf

from utils import sort_ops ,_get_ops_in_path
from GraphBuilder import GraphBuilder
from MergeLayers import merge_layers
from NodeObj import OPNode,TSNode
def remove_identity_const(sorted_ops):
    """Drop ``Identity`` ops from *sorted_ops* in place and return the list.

    For every ``Identity`` op, its first output tensor receives an
    ``identity_from`` attribute pointing at the op's first input, so later
    passes can look through the removed op. ``Const`` ops and all other op
    types are kept unchanged.

    Args:
        sorted_ops: mutable list of ops; each op exposes ``type``,
            ``inputs`` and ``outputs``.

    Returns:
        The same list object, with the ``Identity`` ops removed and the
        relative order of the remaining ops preserved.
    """
    kept = []
    for op in sorted_ops:
        if op.type == 'Identity':
            # Record the alias before dropping the op.
            op.outputs[0].identity_from = op.inputs[0]
        else:
            kept.append(op)
    # Filter into a new list and write it back with slice assignment. Calling
    # ``list.remove`` while iterating the same list skipped the element right
    # after each removal, so consecutive Identity ops used to survive.
    sorted_ops[:] = kept
    return sorted_ops
def read_graph_from_pb(tf_model_path, input_names, output_name):
    """Load a serialized ``GraphDef`` file and extract its ops via ``get_ops_from_pb``.

    Args:
        tf_model_path: path to a binary ``GraphDef`` (.pb) file.
        input_names: list of input tensor names, e.g. ``['input:0']``.
        output_name: name of the output tensor.

    Returns:
        Whatever ``get_ops_from_pb`` returns for the imported graph.
    """
    with open(tf_model_path, 'rb') as pb_file:
        raw_bytes = pb_file.read()
    tf.reset_default_graph()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(raw_bytes)
    with tf.Graph().as_default() as graph:
        # Import without a name prefix so tensor names match input_names/output_name.
        tf.import_graph_def(graph_def, name='')
        # The session object itself is never used here.
        # NOTE(review): presumably kept open so downstream helpers can use the
        # default session — confirm before removing.
        with tf.Session(graph=graph):
            extracted = get_ops_from_pb(graph, input_names, output_name)
    return extracted
def remove_ops_before_inputs(inputs, ops):
    """Drop every op upstream of *inputs*, then keep only ops connected to them.

    Walks backwards from *inputs* through producer ops and marks each one as
    invalid. This includes the ops that produce *inputs* themselves. The
    surviving ops are then filtered by ``get_connected_ops`` so only ops
    connected (through producers or consumers) to *inputs* remain.

    Args:
        inputs: iterable of tf tensors treated as the new graph inputs.
            It is not mutated.
        ops: list of tf ops to filter.

    Returns:
        List of ops, in the order of *ops*, as returned by ``get_connected_ops``.
    """
    # deque.popleft is O(1). The previous pop(0) plus list concatenation was
    # quadratic in the number of tensors visited.
    pending = deque(inputs)
    visited_ts = set()
    invalid_ops = set()
    while pending:
        ts = pending.popleft()
        if ts in visited_ts:
            continue
        visited_ts.add(ts)
        producer = ts.op
        if producer in invalid_ops:
            continue
        invalid_ops.add(producer)
        pending.extend(inp for inp in producer.inputs if inp not in visited_ts)
    kept = [op for op in ops if op not in invalid_ops]
    # Pass a copy so the traversal never consumes the caller's list.
    return get_connected_ops(kept, list(inputs))
def get_connected_ops(ops_set, start_tensors):
    """Return the ops of *ops_set* connected to *start_tensors*.

    Performs a breadth-first search over tensors in both directions. For each
    tensor:
      - its producer op and its consumer ops are visited if they belong to
        *ops_set*;
      - the producer's inputs and the consumers' outputs are queued next.
    Ops outside *ops_set* are never visited, so they act as barriers.

    Args:
        ops_set: iterable of candidate ops (a list in this module's callers).
        start_tensors: tensors to start from. The sequence is not mutated.

    Returns:
        List of the visited ops, in the iteration order of *ops_set*.
    """
    ops_list = list(ops_set)
    allowed = set(ops_list)  # O(1) membership instead of scanning a list
    visited_ts = set()
    visited_ops = set()
    # Copy into a deque. The old code popped from the caller's list in place.
    pending = deque(start_tensors)
    while pending:
        ts = pending.popleft()
        if ts in visited_ts:
            # A tensor can be queued more than once; expand it only once.
            continue
        visited_ts.add(ts)
        producer = ts.op
        if producer in allowed:
            visited_ops.add(producer)
            pending.extend(t for t in producer.inputs if t not in visited_ts)
        for consumer in ts.consumers():
            if consumer in allowed:
                visited_ops.add(consumer)
                pending.extend(t for t in consumer.outputs if t not in visited_ts)
    return [op for op in ops_list if op in visited_ops]
def ops_to_OPNodes(ops, inputs):
    """Wrap tf ops/tensors in OPNode/TSNode objects and rewire them.

    Steps:
      1. Create one ``OPNode`` per op, and one ``TSNode`` per tensor that any
         op reads or writes.
      2. Replace each OPNode's ``inputs``/``outputs`` with the matching
         TSNodes. Give each TSNode ``next_ops`` (its consumers that are in
         *ops*) and ``op`` (its producer's OPNode). When the producer is
         outside *ops*, ``op`` is None and the tensor name is printed.
      3. For every tensor in *inputs*, create a ``tf.placeholder`` with the
         same dtype and shape, and add an OPNode for it. Then replace any node
         input whose ``name`` matches the original input tensor with the
         placeholder tensor.

    Args:
        ops: iterable of tf ops to convert.
        inputs: tf tensors to be replaced by placeholders.

    Returns:
        ``dict_values`` view of all OPNodes, including the placeholder nodes.
    """
    ops_map = {}
    ts_set = set()
    for op in ops:
        ops_map[op] = OPNode(op)
        ts_set.update(op.inputs)
        ts_set.update(op.outputs)
    ts_map = {ts: TSNode(ts, None) for ts in ts_set}

    # Point each OPNode at TSNodes instead of raw tf tensors.
    for op, op_node in ops_map.items():
        op_node.inputs = [ts_map[t] for t in op.inputs]
        op_node.outputs = [ts_map[t] for t in op.outputs]

    for ts, ts_node in ts_map.items():
        # Membership against ops_map (a dict) is O(1); the previous
        # ``op in ops`` scanned the whole op list for every consumer.
        ts_node.next_ops = [ops_map[c] for c in ts.consumers() if c in ops_map]
        ts_node.op = ops_map.get(ts.op, None)
        if ts_node.op is None:
            # The producer is not among the converted ops. For example, this
            # happens for the producer of one of *inputs* when called from
            # get_ops_from_inputs_outputs.
            print('---->', ts_node.name)
    print(inputs)

    # Replace each input tensor with a fresh placeholder of the same dtype/shape.
    replace_input = {}
    for in_ts in inputs:
        input_shape = in_ts.get_shape()
        # Deliberately ``==`` rather than ``is``: get_shape() returns a shape
        # object, never None itself. This presumably relies on TensorShape
        # comparing equal to None when its rank is unknown.
        # NOTE(review): confirm for the TF version in use.
        if input_shape == None:  # noqa: E711
            input_shape = [None, None, None, None]
        ph = tf.placeholder(in_ts.dtype, input_shape)
        print(ph.get_shape())
        replace_input[in_ts.name] = ph
        ops_map[ph.op] = OPNode(ph.op)

    # Swap node inputs that match an input tensor's name for the placeholder.
    # NOTE(review): the replacement is the raw tf tensor, not a TSNode. For the
    # placeholder OPNodes, ``inputs`` is whatever OPNode initialises it to.
    for op_node in ops_map.values():
        op_node.inputs = [replace_input.get(t.name, t) for t in op_node.inputs]
    return ops_map.values()
def get_ops_from_inputs_outputs(graph, inputs, outputs):
    """Convert the part of *graph* connected to *inputs* into OPNodes.

    Args:
        graph: tf graph to read ops from.
        inputs: list of tf input tensors. They are replaced by placeholders
            in ``ops_to_OPNodes``.
        outputs: list of output tensors. NOTE(review): currently unused, so
            ops downstream of the outputs are not pruned here.

    Returns:
        The OPNode collection produced by ``ops_to_OPNodes``.
    """
    all_ops = graph.get_operations()
    # Hand over a copy so the traversal cannot consume the caller's list.
    connected = remove_ops_before_inputs(inputs.copy(), all_ops)
    return ops_to_OPNodes(connected, inputs)
def _write_op_listing(path, ops):
    """Write one ``[input names]---->type--->[output names]`` line per op to *path*."""
    with open(path, 'w+') as listing:
        for op in ops:
            in_names = [v.name for v in op.inputs]
            out_names = [v.name for v in op.outputs]
            listing.write(str(in_names) + '---->' + op.type + '--->' + str(out_names) + '\n')


def get_ops_from_pb(graph, input_names, output_name, save_ori_network=True):
    """Extract, sort and merge the ops of *graph* between the named tensors.

    Side effects: writes ``network.txt`` (the extracted ops) to the current
    working directory. When *save_ori_network* is true, it also writes
    ``ori_network.txt`` (every op in *graph*) there.

    Args:
        graph: tf graph to read from.
        input_names: list of input tensor names.
        output_name: output tensor name.
        save_ori_network: whether to dump the full original op list.

    Returns:
        The result of ``merge_layers(sort_ops(...))`` on the extracted ops.
    """
    if save_ori_network:
        _write_op_listing('ori_network.txt', graph.get_operations())
    input_tensors = [graph.get_tensor_by_name(name) for name in input_names]
    output_tensor = graph.get_tensor_by_name(output_name)
    extracted = get_ops_from_inputs_outputs(graph, input_tensors, [output_tensor])
    _write_op_listing('network.txt', extracted)
    return merge_layers(sort_ops(extracted))
def read_graph_from_ckpt(ckpt_path, input_names, output_name):
    """Restore a checkpoint, freeze it, and extract its ops via ``get_ops_from_pb``.

    Args:
        ckpt_path: checkpoint prefix. ``ckpt_path + '.meta'`` is the meta
            graph, and ``ckpt_path`` is passed to ``saver.restore``.
        input_names: list of input tensor names.
        output_name: output tensor name. Its op name is used as the freeze target.

    Returns:
        Whatever ``get_ops_from_pb`` returns for the frozen graph.
    """
    # Start from an empty default graph, as read_graph_from_pb does, so a
    # second call does not import the meta graph on top of the first one.
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)
        output_tf = graph.get_tensor_by_name(output_name)
        # Replace variables with constants holding their restored values.
        frozen_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), [output_tf.op.name])
    with tf.Graph().as_default() as g:
        tf.import_graph_def(frozen_def, name='')
        with tf.Session(graph=g):
            OPS = get_ops_from_pb(g, input_names, output_name)
    return OPS
def gen_graph(ops, html_dst):
    """Render *ops* to the HTML file *html_dst* using GraphBuilder.

    Ops with no outputs are skipped, as are ``Placeholder`` ops.
    """
    builder = GraphBuilder(html_dst)
    drawable = (
        op for op in ops
        if len(op.outputs) > 0 and op.type != 'Placeholder'
    )
    for op in drawable:
        builder.add_op(op)
    builder.build()
def print_graph(ops):
    """Print each op's inputs followed by its first output (debug helper).

    Ops without outputs print ``None`` for the output instead of raising
    IndexError, mirroring the empty-outputs guard in ``gen_graph``.
    """
    for op in ops:
        first_output = op.outputs[0] if len(op.outputs) > 0 else None
        print(op.inputs, first_output)
def read_graph(model_path, input_names, output_name, html_dst):
    """Load a model, extract the ops between inputs and output, and render them to HTML.

    Args:
        model_path: a frozen graph (any path ending in ``'pb'``); any other
            path is treated as a checkpoint prefix.
        input_names: list of input tensor names.
        output_name: output tensor name.
        html_dst: destination HTML file. Its parent directory is created if
            it does not exist.
    """
    dir_path = os.path.dirname(html_dst)
    if dir_path:
        # exist_ok avoids the race between a separate exists() check and makedirs().
        os.makedirs(dir_path, exist_ok=True)
    if model_path.endswith('pb'):
        ops = read_graph_from_pb(model_path, input_names, output_name)
    else:
        ops = read_graph_from_ckpt(model_path, input_names, output_name)
    print(dir_path)
    gen_graph(ops, html_dst)
if __name__ == '__main__':
    # Usage: python read_graph.py MODEL_PATH INPUT_NAMES OUTPUT_NAME HTML_DST
    #   INPUT_NAMES is a comma-separated list of tensor names.
    # Examples:
    #   python read_graph.py ../../mobilenet_v1_1.0_192.ckpt batch:0 MobilenetV1/Predictions/Reshape_1:0 output/html_dst3.html
    #   python read_graph.py ../../mobilenet_v1_1.0_192_frozen.pb input:0 MobilenetV1/Predictions/Reshape_1:0 output/html_dst1.html
    if len(sys.argv) < 5:
        # Report a usage line instead of a bare IndexError.
        sys.exit('usage: %s MODEL_PATH INPUT_NAMES OUTPUT_NAME HTML_DST' % sys.argv[0])
    model_path, input_names, output_name, html_dst = sys.argv[1:5]
    read_graph(model_path, input_names.split(','), output_name, html_dst)