diff --git a/paddle2onnx/mapper/quantize/rknn_quantize_processor.cc b/paddle2onnx/mapper/quantize/rknn_quantize_processor.cc index 1bc019ba2..af7af5f68 100644 --- a/paddle2onnx/mapper/quantize/rknn_quantize_processor.cc +++ b/paddle2onnx/mapper/quantize/rknn_quantize_processor.cc @@ -90,23 +90,33 @@ void RKNNQuantizeProcessor::AddQDQ() { if (helper_->quantize_info.find(name) != helper_->quantize_info.end()) { continue; } - std::vector matmul_weight; - if (!GetTensorByName(name, &matmul_weight)) { + std::vector weight_data; + if (!GetTensorByName(name, &weight_data)) { continue; } - std::vector matmul_weight_shape; - if (!GetTensorShape(name, &matmul_weight_shape)) { + std::vector weight_shape; + if (!GetTensorShape(name, &weight_shape)) { continue; } + int64_t quantize_axis = 1; std::vector scale; std::vector zeros; - GetChannelWiseQuantizeInfo(matmul_weight, matmul_weight_shape, - quantize_axis, &scale, &zeros); - auto scale_node = - helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale); - auto zero_node = - helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros); + std::string scale_node, zero_node; + if (weight_shape.size() <= 1) { + GetTensorWiseQuantizeInfo(weight_data, &scale, &zeros); + scale_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, + scale[0]); + zero_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::INT8, + zeros[0]); + } else { + GetChannelWiseQuantizeInfo(weight_data, weight_shape, quantize_axis, + &scale, &zeros); + scale_node = + helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale); + zero_node = + helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros); + } QuantizeInfo matmul_weight_quantize_info(scale, zeros, scale_node, zero_node, quantize_axis); helper_->quantize_info[name] = matmul_weight_quantize_info; @@ -169,13 +179,11 @@ void RKNNQuantizeProcessor::PerchannelToPerlayer() { auto next_nodes = name2node_dict_[node->output(0)]; if (next_nodes.size() > 1 || IsGraphOutput(node->output(0))) { - P2OLogger() << "Type1" << std::endl; continue; } auto add_node = next_nodes[0]; if (add_node->op_type() != "Add" || IsGraphOutput(add_node->output(0))) { - P2OLogger() << "Type2" << std::endl; continue; }