Skip to content

Commit

Permalink
Fix flops api. (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang authored Dec 19, 2019
1 parent 3151326 commit 09f99c9
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions paddleslim/analysis/flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@ def _graph_flops(graph, only_conv=False, detail=False):
for op in graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d']:
filter_shape = op.inputs("Filter")[0].shape()
input_shape = op.inputs("Input")[0].shape()
output_shape = op.outputs("Output")[0].shape()
_, c_in, _, _ = input_shape
c_out, _, k_h, k_w = filter_shape
c_out, c_in, k_h, k_w = filter_shape
_, _, h_out, w_out = output_shape
groups = op.attr("groups")
kernel_ops = k_h * k_w * (float(c_in) / groups)
# c_in is the channel number of filter. It is (input_channel // groups).
kernel_ops = k_h * k_w * float(c_in)
if len(op.inputs("Bias")) > 0:
with_bias = 1
else:
Expand All @@ -50,7 +48,6 @@ def _graph_flops(graph, only_conv=False, detail=False):
flops += op_flops
params2flops[op.inputs("Filter")[0].name()] = op_flops
elif op.type() == 'pool2d' and not only_conv:
input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs("Out")[0].shape()
_, c_out, h_out, w_out = output_shape
k_size = op.attr("ksize")
Expand Down

0 comments on commit 09f99c9

Please sign in to comment.