Support quantization of condition block #37498

Merged
11 commits merged on Dec 10, 2021
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -149,6 +149,7 @@ def __init__(self,
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
is_use_cache_file=False,
quant_blocks=-1,
cache_dir=None):
'''
Constructor.
@@ -221,6 +222,8 @@ def __init__(self,
be different. To address this problem, fuse the pattern before
quantization. Default False.
is_use_cache_file(bool, optional): This param is deprecated.
quant_blocks(int|list, optional): The block id list for quantization.
Default is -1, which quantizes all blocks. It can also be set to a list such as [0, 1].
cache_dir(str, optional): This param is deprecated.
Returns:
None
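For reference, here is a minimal usage sketch of the new argument (not part of the PR). The model directory, calibration reader, and batch settings are placeholders, and the other constructor arguments are assumed to follow the existing PostTrainingQuantization API.

import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

def calib_reader():
    # Placeholder calibration sample generator.
    for _ in range(10):
        yield [np.random.random((1, 3, 224, 224)).astype('float32')]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='./infer_model',      # placeholder inference model directory
    sample_generator=calib_reader,
    batch_size=16,
    batch_nums=10,
    algo='KL',
    quant_blocks=[0, 1])            # new argument: quantize only blocks 0 and 1 (-1 = all blocks)

ptq.quantize()
ptq.save_quantized_model('./quant_model')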
@@ -309,6 +312,7 @@ def __init__(self,
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._is_full_quantize = is_full_quantize
self._quant_blocks = quant_blocks
if is_full_quantize:
self._quantizable_op_type = self._support_quantize_op_type
else:
@@ -451,10 +455,6 @@ def _load_model_data(self):
model_filename=self._model_filename,
params_filename=self._params_filename)

if self._program.num_blocks > 1:
_logger.error("The post training quantization requires that the "
"program only has one block.")

if self._optimize_model:
self._optimize_fp32_model()

@@ -505,23 +505,30 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
self._quantized_act_var_name.add(var_name)

persistable_var_names = _all_persistable_var_names(self._program)
for op in self._program.global_block().ops:
op_type = op.type
if self._is_full_quantize and \
op_type not in self._quantizable_op_type:
_logger.warning(op_type + " is not supported for quantization.")
# For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type:
collect_var_name(
_get_op_input_var_names(op), persistable_var_names, op_type)
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
for block_id in range(len(self._program.blocks)):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks,
list) and block_id not in self._quant_blocks:
continue
for op in self._program.blocks[block_id].ops:
op_type = op.type
if self._is_full_quantize and \
op_type not in self._quantizable_op_type:
_logger.warning(op_type +
" is not supported for quantization.")
# For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type:
collect_var_name(
_get_op_input_var_names(op), persistable_var_names,
op_type)
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)

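The same per-block guard is repeated in each rewritten loop of this file. A minimal stand-alone sketch of the filtering rule it implements (the helper name is hypothetical and not part of the PR): quant_blocks == -1 means every block is processed, while a list restricts processing to the listed block ids.

def should_quantize_block(quant_blocks, block_id):
    # -1 (the default) selects every block.
    if quant_blocks == -1:
        return True
    # A list selects only the listed block ids.
    if isinstance(quant_blocks, list):
        return block_id in quant_blocks
    return True

# Usage mirroring the loops in this diff:
# for block_id in range(len(program.blocks)):
#     if not should_quantize_block(quant_blocks, block_id):
#         continue
#     for op in program.blocks[block_id].ops:
#         ...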
def _set_activation_persistable(self):
'''
@@ -696,16 +703,21 @@ def _save_input_threhold(self):
'''
assert self._algo == "min_max", \
"The algo should be min_max to save input threshold."
for op in self._program.global_block().ops:
if op.type in self._quantizable_op_type:
for var_name in _get_op_input_var_names(op):
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
op._set_attr(var_name + ".min",
self._quantized_var_min[var_name])
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
op._set_attr("with_quant_attr", True)
for block_id in range(len(self._program.blocks)):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks,
list) and block_id not in self._quant_blocks:
continue
for op in self._program.blocks[block_id].ops:
if op.type in self._quantizable_op_type:
for var_name in _get_op_input_var_names(op):
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
op._set_attr(var_name + ".min",
self._quantized_var_min[var_name])
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
op._set_attr("with_quant_attr", True)

def _collect_activation_abs_min_max(self):
'''
@@ -795,7 +807,13 @@ def _update_program(self):
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
transform_pass.apply(graph)

for i, sub_graph in enumerate(graph.all_sub_graphs()):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks, list) and i not in self._quant_blocks:
continue
sub_graph._for_test = True
transform_pass.apply(sub_graph)

# use AddQuantDequantPass to insert fake_quant_dequant op
minor_quantizable_op_types = []
@@ -806,7 +824,13 @@ def _update_program(self):
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types)
add_quant_dequant_pass.apply(graph)

for i, sub_graph in enumerate(graph.all_sub_graphs()):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks, list) and i not in self._quant_blocks:
continue
sub_graph._for_test = True
add_quant_dequant_pass.apply(sub_graph)

# save threshold to scale var node
if self._algo in ["KL", "hist"]:
@@ -836,7 +860,14 @@ def _update_program(self):
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
freeze_pass.apply(graph)

for i, sub_graph in enumerate(graph.all_sub_graphs()):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks, list) and i not in self._quant_blocks:
continue
sub_graph._for_test = True
freeze_pass.apply(sub_graph)

self._program = graph.to_program()

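The three loops above (for transform_pass, add_quant_dequant_pass, and freeze_pass) share one structure: enumerate the sub-graphs of the IrGraph, skip block ids that were not selected, mark the sub-graph as a test graph, and apply the pass. A sketch of that shared pattern, with a_pass standing in for any of the three passes:

def apply_pass_to_selected_blocks(a_pass, graph, quant_blocks):
    for i, sub_graph in enumerate(graph.all_sub_graphs()):
        # Skip sub-graphs whose block id is not selected for quantization.
        if quant_blocks != -1 and isinstance(quant_blocks, list) \
                and i not in quant_blocks:
            continue
        sub_graph._for_test = True  # the passes expect inference (test) graphs
        a_pass.apply(sub_graph)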
def _save_output_threshold(self):
@@ -888,13 +919,19 @@ def analysis_and_save_info(op_node, out_var_name):
save_info(op_node, out_var_name, self._quantized_var_max,
"out_max", "post_min_max")

for op in self._program.global_block().ops:
if op.type in (self._quantizable_op_type + self._out_scale_op_list):
out_var_names = _get_op_output_var_names(op)
assert len(out_var_names) == 1, "Post training " + \
"quantization only support one output for " + op.type
for var_name in out_var_names:
analysis_and_save_info(op, var_name)
for block_id in range(len(self._program.blocks)):
if self._quant_blocks != -1 and isinstance(
self._quant_blocks,
list) and block_id not in self._quant_blocks:
continue
for op in self._program.blocks[block_id].ops:
if op.type in (
self._quantizable_op_type + self._out_scale_op_list):
out_var_names = _get_op_output_var_names(op)
assert len(out_var_names) == 1, "Post training " + \
"quantization only support one output for " + op.type
for var_name in out_var_names:
analysis_and_save_info(op, var_name)

def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
"""
33 changes: 22 additions & 11 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
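Every hunk in this file makes the same mechanical change: the attrs dict passed to graph.create_op_node() for an inserted fake quant/dequant (or scale) op gains an empty 'op_device' entry next to 'op_role'. The diff does not state the rationale, so the stand-alone sketch below only shows the shape of the change; the values are stubbed.

OP_ROLE_FORWARD = 0   # stand-in for core.op_proto_and_checker_maker.OpRole.Forward
quant_bits = 8        # stand-in value

attrs_before = {
    'bit_length': quant_bits,
    'op_role': OP_ROLE_FORWARD,
}
# After this PR the same dict also carries an explicit, empty device attribute.
attrs_after = dict(attrs_before, op_device='')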
@@ -582,7 +582,8 @@ def _create_global_step(self, graph):
attrs={
'step': 1.0,
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': global_step_in},
outputs={'Out': global_step_out})
@@ -632,7 +633,8 @@ def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
op_type='fake_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
@@ -694,7 +696,8 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
@@ -778,7 +781,8 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node, name,
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
}

quant_op_node = graph.create_op_node(
@@ -831,7 +835,8 @@ def _insert_channel_quant_op(self, graph, var_node, name, quant_bits,
'bit_length': quant_bits,
'quant_axis': quant_axis,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
@@ -857,7 +862,8 @@ def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
op_type='fake_dequantize_max_abs',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': var_node,
'Scale': scale_var_node},
@@ -884,7 +890,8 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
attrs={
'quant_bits': quant_bits,
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': var_node,
'Scales': scale_var_nodes},
@@ -1303,7 +1310,8 @@ def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
'quant_bits': [self._weight_bits, self._activation_bits],
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'x_num_col_dims': x_num_col_dims
'x_num_col_dims': x_num_col_dims,
'op_device': ''
},
inputs={
'X': output_var_node,
@@ -1359,7 +1367,8 @@ def _insert_post_dequant_op(self, graph, op_node):
op_type='fake_dequantize_max_abs',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
},
inputs={'X': output_var_node,
'Scale': scale_var_node},
@@ -1707,7 +1716,8 @@ def apply(self, graph):
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
@@ -1987,7 +1997,8 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_device': ''
}

quant_op_node = graph.create_op_node(