From 0ee47819e0c657badf6055019adde9aa1eb9fde7 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 11 Aug 2023 09:07:13 +0200 Subject: [PATCH] fix for mobilenet ptq --- core/src/ops/cnn/conv/unary.rs | 5 +++-- tflite/src/ops/cnn.rs | 4 ++-- tflite/src/ops/mod.rs | 9 +++++++-- tflite/src/ops/nn.rs | 1 + 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs index 5c32cf79b4..72db790957 100644 --- a/core/src/ops/cnn/conv/unary.rs +++ b/core/src/ops/cnn/conv/unary.rs @@ -669,8 +669,9 @@ impl ConvUnary { let mut new = self.clone(); new.pool_spec.padding = padding; let mut patch = TypedModelPatch::default(); - let wire = patch.tap_model(model, prec.inputs[0])?; - let wire = patch.wire_node(&node.name, new, &[wire])?; + let mut wire = patch.taps(model, &node.inputs)?; + wire[0] = patch.tap_model(model, prec.inputs[0])?; + let wire = patch.wire_node(&node.name, new, &wire)?; patch.shunt_outside(model, node.id.into(), wire[0])?; Ok(Some(patch)) } diff --git a/tflite/src/ops/cnn.rs b/tflite/src/ops/cnn.rs index 097a27758c..0711ded75a 100644 --- a/tflite/src/ops/cnn.rs +++ b/tflite/src/ops/cnn.rs @@ -173,7 +173,7 @@ fn de_conv2d(op: &mut DeserOp) -> TractResult> { output_channel_override: Some(*co), }; let mut inputs = tvec!(op.inputs[0]); - let q_params = super::linearops_quantization_suport(op, &input, &mut inputs)?; + let q_params = super::linearops_quantization_suport(op, &input, &mut inputs, true)?; let bias_dt = bias.datum_type().unquantized(); let bias = bias.into_tensor().cast_to_dt(bias_dt)?.into_owned().into_arc_tensor(); let conv = core::cnn::ConvUnary { @@ -213,7 +213,7 @@ fn de_dw_conv2d(op: &mut DeserOp) -> TractResult> { output_channel_override: Some(co), }; let mut inputs = tvec!(op.inputs[0]); - let q_params = super::linearops_quantization_suport(op, &input, &mut inputs)?; + let q_params = super::linearops_quantization_suport(op, &input, &mut inputs, true)?; let conv = 
core::cnn::ConvUnary { pool_spec, kernel_fmt: KernelFormat::OHWI, diff --git a/tflite/src/ops/mod.rs b/tflite/src/ops/mod.rs index 7fabda1e08..f0c2132bef 100644 --- a/tflite/src/ops/mod.rs +++ b/tflite/src/ops/mod.rs @@ -55,6 +55,7 @@ fn linearops_quantization_suport( op: &mut DeserOp, input: &TypedFact, inputs: &mut TVec, + kscale_is_per_axis: bool, ) -> TractResult> { if op.output_facts[0].datum_type.is_quantized() { let p = &op.prefix; @@ -63,11 +64,15 @@ fn linearops_quantization_suport( let k_input = op.flat.inputs().unwrap().get(1); let k_tensor = op.ctx.subgraph.tensors().unwrap().get(k_input as usize); let k_qp = k_tensor.quantization().unwrap(); - let kscale = k_qp.scale().unwrap().iter().collect_vec(); + let k_scale = if kscale_is_per_axis { + rctensor1(&k_qp.scale().unwrap().iter().collect_vec()) + } else { + rctensor0(k_qp.scale().unwrap().get(0)) + }; let k_zp = k_qp.zero_point().unwrap().iter().map(|i| i as i32).collect_vec(); ensure!(k_zp.iter().all(|x| *x == 0)); inputs.push(op.ctx.target.add_const(format!("{p}.k0"), rctensor0(0i8))?); - inputs.push(op.ctx.target.add_const(format!("{p}.kscale"), rctensor1(&kscale))?); + inputs.push(op.ctx.target.add_const(format!("{p}.kscale"), k_scale)?); inputs.push(op.ctx.target.add_const(format!("{p}.i0"), rctensor0(iqp.zp_scale().0 as i8))?); inputs.push(op.ctx.target.add_const(format!("{p}.iscale"), rctensor0(iqp.zp_scale().1))?); inputs.push(op.ctx.target.add_const(format!("{p}.c0"), rctensor0(oqp.zp_scale().0 as i8))?); diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs index 0912650c81..d459b5c2ff 100644 --- a/tflite/src/ops/nn.rs +++ b/tflite/src/ops/nn.rs @@ -31,6 +31,7 @@ fn de_fully_connected(op: &mut DeserOp) -> TractResult> { op, &input, &mut inputs, + false, )?; let operating_dt = if input.datum_type.is_float() { input.datum_type } else { i32::datum_type() };