Skip to content

Commit

Permalink
fix dimensions error for mobilenetv1_KL_quant (#26776)
Browse files Browse the repository at this point in the history
* fix dimensions error for mobilenetv1_KL_quant

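Fixes the following error, raised when the length of the weight scales vector matches the second dimension of the weights tensor rather than the first: "AssertionError: The size of weight scales vector (1000) does not match the number of output channels (1024) in the weights tensor fc7_weights."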

add mul test

* remove comment

* add third case unit test
  • Loading branch information
sfraczek authored Sep 7, 2020
1 parent 24ec517 commit eb65877
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,14 @@ def _dequantize_op_weights(self, graph, op_node, weight_name, output_name):
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_var_name)
assert scales.size == 1 or scales.size == len(
weight
), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format(
scales.size, len(weight), weight_var_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
else:
raise ValueError(
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
.format(scales.size, weight.shape, weight_var_name))
w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
self._restore_var(weight_var_name, w_fp32)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,30 @@ def setUp(self):
self.conv_output = np.ndarray(self.conv_output_size).astype(self.dtype)
self.conv_output2 = np.ndarray(self.conv_output2_size).astype(
self.dtype)
self.quantized_ops = 'conv2d'
self.quantized_ops = 'conv2d,mul'
self.variables = {
"input": self.input,
"filter": self.filter,
"filter2": self.filter2,
"conv_output": self.conv_output,
"conv_output2": self.conv_output2,
}
self.mul_input_size = [1, 3]
self.mul_weights_size = [3, 5]
self.mul_output_size = [1, 5]
self.mul_input = np.random.random(self.mul_input_size).astype(
self.dtype)
self.mul_weights = np.ones(self.mul_weights_size, self.dtype)
self.mul_weights_bad = np.ones([1, 1], self.dtype)
self.mul_output = np.ndarray(self.mul_output_size).astype(self.dtype)
self.mul_output_scale = np.linspace(1, 5, num=5).astype(self.dtype)

self.variables_mul = {
"mul_input": self.mul_input,
"mul_weights": self.mul_weights,
"mul_output": self.mul_output,
"mul_weights_bad": self.mul_weights_bad
}

def prepare_program(self, program):
block = program.global_block()
Expand Down Expand Up @@ -92,6 +108,23 @@ def prepare_program(self, program):
'fuse_brelu': True
})

def prepare_program_mul(self, program):
    """Declare the mul test variables in ``program`` and append one mul op.

    Each entry of ``self.variables_mul`` becomes a float32 variable in the
    program's global block, shaped like the corresponding array. A single
    ``mul`` op is then appended that reads ``mul_input`` (X) and
    ``mul_weights`` (Y) and writes ``mul_output`` (Out).
    """
    global_block = program.global_block()
    for var_name, array in self.variables_mul.items():
        global_block.create_var(
            name=var_name, dtype="float32", shape=array.shape)

    mul_inputs = {
        "X": global_block.var('mul_input'),
        "Y": global_block.var('mul_weights'),
    }
    mul_outputs = {"Out": global_block.var('mul_output')}
    global_block.append_op(
        type="mul",
        inputs=mul_inputs,
        outputs=mul_outputs,
        attrs={'use_mkldnn': self.use_mkldnn})

def remove_fuse_activation_attribute(self, graph):
for op in graph.all_op_nodes():
op.op().remove_attr("fuse_activation")
Expand All @@ -103,11 +136,13 @@ def check_graph_before_pass(self, graph):

def check_graph_after_pass(self, graph):
    """Assert that every conv2d op has the expected ``fuse_activation``.

    Only ops of type ``conv2d`` are inspected; ops of any other type are
    skipped. For each conv2d op the ``fuse_activation`` attribute must
    exist, and it must equal ``"relu"`` when ``fuse_relu`` is set or
    ``"relu6"`` when ``fuse_brelu`` is set.
    """
    for op in graph.all_op_nodes():
        op_desc = op.op()
        if op_desc.type() != "conv2d":
            continue
        self.assertTrue(op_desc.has_attr("fuse_activation"))
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        if op_desc.has_attr("fuse_relu") and op_desc.attr("fuse_relu"):
            self.assertEqual(op_desc.attr("fuse_activation"), "relu")
        if op_desc.has_attr("fuse_brelu") and op_desc.attr("fuse_brelu"):
            self.assertEqual(op_desc.attr("fuse_activation"), "relu6")

def test_quant_update_activation(self):
program = fluid.Program()
Expand All @@ -125,6 +160,39 @@ def test_quant_update_activation(self):
graph = quant2_int8_mkldnn_pass._update_activations(graph)
self.check_graph_after_pass(graph)

def test_dequantize_op_weights(self):
    """Exercise ``_dequantize_op_weights`` on the Y weights of a mul op.

    First case: all-ones weights of shape [3, 5] with five output scales
    (1..5) must be restored as 127 / scale in every row. Second case: a
    [1, 1] weights tensor whose dimensions do not match the five scales
    must raise ``ValueError``.
    """
    program = fluid.Program()
    with fluid.program_guard(program):
        self.prepare_program_mul(program)
        graph = IrGraph(core.Graph(program.desc), for_test=True)

        # Use an explicit failure instead of a NameError further down when
        # the graph unexpectedly contains no mul op.
        op_node = next(
            (op for op in graph.all_op_nodes() if op.op().type() == "mul"),
            None)
        self.assertIsNotNone(op_node, "op of type mul not found")

        qpass = Quant2Int8MkldnnPass(
            self.quantized_ops,
            _scope=self.scope,
            _place=self.place,
            _core=core,
            _debug=False)
        qpass._weight_scales["mul_output"] = self.mul_output_scale
        param = self.scope.var("mul_weights").get_tensor()
        param.set(self.variables_mul["mul_weights"], self.place)
        qpass._dequantize_op_weights(graph, op_node, "Y", "Out")

        # unittest assertion rather than a bare assert, which is removed
        # when Python runs with -O.
        self.assertTrue(
            np.allclose(
                self.scope.find_var("mul_weights").get_tensor(),
                [[127, 63.5, 42.3333, 31.75, 25.4],
                 [127, 63.5, 42.3333, 31.75, 25.4],
                 [127, 63.5, 42.3333, 31.75, 25.4]]))

        # Swap in weights whose shape matches none of the accepted layouts.
        param = self.scope.var("mul_weights").get_tensor()
        param.set(self.variables_mul["mul_weights_bad"], self.place)
        with self.assertRaises(ValueError):
            qpass._dequantize_op_weights(graph, op_node, "Y", "Out")


# Run the unit tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit eb65877

Please sign in to comment.