Skip to content

Commit

Permalink
add graph backend (PaddlePaddle#87)
Browse files Browse the repository at this point in the history
based on graphviz
  • Loading branch information
Superjomn authored Jan 12, 2018
1 parent f1339d1 commit cbb7946
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 132 deletions.
293 changes: 169 additions & 124 deletions visualdl/server/graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import json

from google.protobuf.json_format import MessageToJson

import onnx
import graphviz_graph as gg
from PIL import Image


def debug_print(json_obj):
print(json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
print(json.dumps(
json_obj, sort_keys=True, indent=4, separators=(',', ': ')))


def reorganize_inout(json_obj, key):
Expand Down Expand Up @@ -78,15 +82,16 @@ def get_links(model_json):
name = input['name']
for node in model_json['node']:
if name in node['input']:
links.append({'source': name,
"target": node['name']})
links.append({'source': name, "target": node['name']})

for source_node in model_json['node']:
for output in source_node['output']:
for target_node in model_json['node']:
if output in target_node['input']:
links.append({'source': source_node['name'],
'target': target_node['name']})
links.append({
'source': source_node['name'],
'target': target_node['name']
})

return links

Expand Down Expand Up @@ -189,8 +194,6 @@ def get_level_to_all(node_links, model_json):
level_to_nodes[level] = list()
level_to_nodes[level].append(idx)
# debug_print(level_to_nodes)


"""
input_to_level {idx -> level}
level_to_inputs {level -> [input1, input2]}
Expand Down Expand Up @@ -231,7 +234,8 @@ def get_level_to_all(node_links, model_json):
if out_level not in output_to_level:
output_to_level[out_idx] = out_level
else:
raise Exception("output " + out_name + "have multiple source")
raise Exception(
"output " + out_name + "have multiple source")
level_to_outputs = dict()
for out_idx in output_to_level:
level = output_to_level[out_idx]
Expand All @@ -243,7 +247,12 @@ def get_level_to_all(node_links, model_json):

def init_level(level):
if level not in level_to_all:
level_to_all[level] = {'nodes': list(), 'inputs': list(), 'outputs': list()}
level_to_all[level] = {
'nodes': list(),
'inputs': list(),
'outputs': list()
}

# merge all levels
for level in level_to_nodes:
init_level(level)
Expand Down Expand Up @@ -321,116 +330,6 @@ def add_edges(json_obj):
return json_obj


def transform_for_echars(model_json):
opItemStyle = {
"normal": {
"color": '#d95f02'
}
}

paraterItemStyle = {
"normal": {
"color": '#1b9e77'
}
};

paraSymbolSize = [12, 6]
paraSymbol = 'rect'
opSymbolSize = [5, 5]

option = {
"title": {
"text": 'Default Graph Name'
},
"tooltip": {
"show": False
},
"animationDurationUpdate": 1500,
"animationEasingUpdate": 'quinticInOut',
"series": [
{
"type": "graph",
"layout": "none",
"symbolSize": 8,
"roam": True,
"label": {
"normal": {
"show": True,
"color": 'black'
}
},
"edgeSymbol": ['none', 'arrow'],
"edgeSymbolSize": [0, 10],
"edgeLabel": {
"normal": {
"textStyle": {
"fontSize": 20
}
}
},
"lineStyle": {
"normal": {
"opacity": 0.9,
"width": 2,
"curveness": 0
}
},
"data": [],
"links": []
}
]
}

option['title']['text'] = model_json['name']

rename_model(model_json)
node_links = get_node_links(model_json)
add_level_to_node_links(node_links)
level_to_all = get_level_to_all(node_links, model_json)
node_to_coordinate, input_to_coordinate, output_to_coordinate = level_to_coordinate(level_to_all)

inputs = model_json['input']
nodes = model_json['node']
outputs = model_json['output']

echars_data = list()

for in_idx in range(len(inputs)):
input = inputs[in_idx]
data = dict()
data['name'] = input['name']
data['x'] = input_to_coordinate[in_idx]['x']
data['y'] = input_to_coordinate[in_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)
for node_idx in range(len(nodes)):
node = nodes[node_idx]
data = dict()
data['name'] = node['name']
data['x'] = node_to_coordinate[node_idx]['x']
data['y'] = node_to_coordinate[node_idx]['y']
data['itemStyle'] = opItemStyle
data['symbolSize'] = opSymbolSize
echars_data.append(data)
for out_idx in range(len(outputs)):
output = outputs[out_idx]
data = dict()
data['name'] = output['name']
data['x'] = output_to_coordinate[out_idx]['x']
data['y'] = output_to_coordinate[out_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)

option['series'][0]['data'] = echars_data
option['series'][0]['links'] = get_links(model_json)

return option


def to_IR_json(model_pb_path):
model = onnx.load(model_pb_path)
graph = model.graph
Expand All @@ -446,14 +345,160 @@ def to_IR_json(model_pb_path):

def load_model(model_pb_path):
model_json = to_IR_json(model_pb_path)
options = transform_for_echars(model_json)
return options
model_json = add_edges(model_json)
return model_json


class GraphPreviewGenerator(object):
def __init__(self, model_json):
#self.model = json.loads(model_json)
self.model = model_json
# init graphviz graph
self.graph = gg.Graph(
self.model['name'],
layout="dot",
#resolution=200,
concentrate="true",
# rankdir="LR"
rankdir="TB",
)

self.op_rank = self.graph.rank_group('same', 2)
self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0)

def __call__(self, path='temp.dot'):
self.nodes = {}
self.params = set()
self.ops = set()
self.args = set()

for item in self.model['input'] + self.model['output']:
node = self.add_param(**item)
print 'name', item['name']
self.nodes[item['name']] = node
self.params.add(item['name'])

for id, item in enumerate(self.model['node']):
node = self.add_op(**item)
name = "node_" + str(id)
print 'name', name
self.nodes[name] = node
self.ops.add(name)

for item in self.model['edges']:
source = item['source']
target = item['target']

if source not in self.nodes:
self.nodes[source] = self.add_arg(source)
self.args.add(source)
if target not in self.nodes:
self.nodes[target] = self.add_arg(target)
self.args.add(target)

if source in self.args or target in self.args:
edge = self.add_edge(
style="dashed,bold", color="#aaaaaa", **item)
else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item)

self.graph.display(path)

def add_param(self, name, data_type, shape):
label = '\n'.join([
'<<table cellpadding="5">',
' <tr>',
' <td bgcolor="#eeeeee">',
name,
' </td>'
' </tr>',
' <tr>',
' <td>',
data_type,
' </td>'
' </tr>',
' <tr>',
' <td>',
'[%s]' % 'x'.join(shape),
' </td>'
' </tr>',
'</table>>',
])
return self.graph.node(
label,
prefix="param",
shape="none",
# rank=self.param_rank,
style="rounded,filled,bold",
width="1.3",
#color="#ffa0a0",
color="#8cc7ff",
fontname="Arial")

def add_op(self, opType, **kwargs):
return self.graph.node(
gg.crepr(opType),
# rank=self.op_rank,
prefix="op",
shape="box",
style="rounded, filled, bold",
fillcolor="#8cc7cd",
#fillcolor="#8cc7ff",
fontname="Arial",
width="1.3",
height="0.84",
)

def add_arg(self, name):
return self.graph.node(
gg.crepr(name),
prefix="arg",
# rank=self.arg_rank,
shape="box",
style="rounded,filled,bold",
fontname="Arial",
color="grey")

def add_edge(self, source, target, label, **kwargs):
source = self.nodes[source]
target = self.nodes[target]
return self.graph.edge(source, target, **kwargs)


def draw_graph(model_pb_path, image_dir):
json_str = load_model(model_pb_path)
best_image = None
min_width = None
for i in range(10):
# randomly generate dot images and select the one with minimum width.
g = GraphPreviewGenerator(json_str)
dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
g(dot_path)

try:
im = Image.open(image_path)
if min_width is None or im.size[0] < min_width:
min_width = im.size
best_image = image_path
except:
pass
return best_image


if __name__ == '__main__':
import os
import sys
current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
# json_str = load_model(current_path + "/mock/inception_v1_model.pb")
json_str = load_model(current_path + "/mock/squeezenet_model.pb")
print(json_str)
json_str = load_model(current_path + "/mock/inception_v1_model.pb")
#json_str = load_model(current_path + "/mock/squeezenet_model.pb")
# json_str = load_model('./mock/shufflenet/model.pb')
debug_print(json_str)
assert json_str

g = GraphPreviewGenerator(json_str)
g('./temp.dot')
# for i in range(10):
# g = GraphPreviewGenerator(json_str)
# g('./temp-%d.dot' % i)
Loading

0 comments on commit cbb7946

Please sign in to comment.