diff --git a/.travis/regular-tests.sh b/.travis/regular-tests.sh
index f1dc649223..1466b87dc4 100755
--- a/.travis/regular-tests.sh
+++ b/.travis/regular-tests.sh
@@ -65,9 +65,17 @@ then
     exit 0
 fi
+OLD_CACHEDIR=$CACHEDIR
+mkdir -p $CACHEDIR/big
+export CACHEDIR=$CACHEDIR/big
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p core-proptest-pulse $ALL_FEATURES
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p lstm-proptest-onnx-vs-tf $ALL_FEATURES
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p nnef-inceptionv3 $ALL_FEATURES
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-inceptionv3 $ALL_FEATURES
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-mobilenet-v2 $ALL_FEATURES
 cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-moz-deepspeech $ALL_FEATURES
+if [ -n "$GITHUB_ACTIONS" ]
+then
+    rm -r $OLD_CACHEDIR/big
+fi
+CACHEDIR=$OLD_CACHEDIR
diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs
index 324d9b3926..5c32cf79b4 100644
--- a/core/src/ops/cnn/conv/unary.rs
+++ b/core/src/ops/cnn/conv/unary.rs
@@ -159,7 +159,7 @@ impl ConvUnary {
         use crate::ops::matmul::mir_quant as qmm;

         let c_dt = self.q_params.unwrap();
-        let [a0, a_scale, mut b0, b_scale, c0, c_scale] = wires[1..] else {
+        let [a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires[1..] else {
             bail!("Wrong number of inputs")
         };

@@ -168,6 +168,18 @@ impl ConvUnary {
         let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
         let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

+        if !model.outlet_fact(a_scale)?.shape.volume().is_one() {
+            // requant is performed before geo_reshape, so we need at most one geo axis to the
+            // right
+            if !output_shape.fmt.c_is_last() {
+                a_scale = model.wire_node(
+                    format!("{name}.a_scale_axis_fix"),
+                    AxisOp::Add(1),
+                    &[a_scale],
+                )?[0];
+            }
+        }
+
         let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;

         let im2col = model.wire_node(
@@ -641,7 +653,7 @@ impl ConvUnary {
             return Ok(None);
         };
         let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
-        if value.cast_to_scalar::<i64>()? != 0
+        if !value.is_zero()?
             || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
             || pad.pads[shape.c_axis()] != (0, 0)
         {
@@ -671,27 +683,21 @@ impl ConvUnary {
         if self.q_params.is_some() || self.group != 1 {
             return Ok(None);
         }
-        let &[succ] = &*node.outputs[0].successors else {
-            return Ok(None)
-        };
-        let Some(bin) = model.node(succ.node).op_as::<TypedBinOp>() else {
-            return Ok(None)
-        };
+        let &[succ] = &*node.outputs[0].successors else { return Ok(None) };
+        let Some(bin) = model.node(succ.node).op_as::<TypedBinOp>() else { return Ok(None) };
         let other_input = model.node(succ.node).inputs[1 - succ.slot];
         let other_fact = &model.outlet_fact(other_input)?;
-        let Some(konst) = &other_fact.konst else {
-            return Ok(None)
-        };
+        let Some(konst) = &other_fact.konst else { return Ok(None) };
         let axes_mapping = model.node_axes_mapping(succ.node)?;
         let input_shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
         let conv_c_axis = input_shape.c_axis();
-        let &[konst_c_axis] = &*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1- succ.slot] else {
-            return Ok(None)
-        };
-        let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else {
-            return Ok(None)
+        let &[konst_c_axis] =
+            &*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1 - succ.slot]
+        else {
+            return Ok(None);
         };
+        let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else { return Ok(None) };
         let operand_for_bias = if konst.shape()[konst_c_axis] == co && konst.len() == co {
             konst.clone().into_tensor().into_shape(&[co])?
         } else if konst.len() == 1 {
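Note: the a_scale hunk above is about broadcasting a per-channel scale. A rank-1 scale of length co multiplies cleanly against a channels-last output, but for channels-first layouts the scale needs trailing unit axes so elementwise broadcasting aligns on the C axis; the single AxisOp::Add(1) covers the one geometry axis still present at requant time. A standalone sketch of the shape rule (plain Rust, illustrative only, not tract API):

    // Shape a per-channel scale so it broadcasts against an output whose
    // channel axis is `c_axis` out of `rank` axes.
    fn broadcast_scale_shape(co: usize, rank: usize, c_axis: usize) -> Vec<usize> {
        let mut shape = vec![1; rank];
        shape[c_axis] = co;
        shape
    }

    fn main() {
        assert_eq!(broadcast_scale_shape(8, 4, 1), vec![1, 8, 1, 1]); // NCHW: unit axes needed to the right
        assert_eq!(broadcast_scale_shape(8, 4, 3), vec![1, 1, 1, 8]); // NHWC: C is last, nothing to fix
    }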
diff --git a/core/src/ops/einsum/codegen.rs b/core/src/ops/einsum/codegen.rs
index 884b4a2f8d..910fd06263 100644
--- a/core/src/ops/einsum/codegen.rs
+++ b/core/src/ops/einsum/codegen.rs
@@ -93,7 +93,13 @@ pub(super) fn ensure_mkn_axes<'a>(
         })
         .max_by_key(|a| &output_shape[a.outputs[0][0]]);
     let Some(n_axis) = n_axis else {
-        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, true, &[k_axis, m_axis])?));
+        return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(
+            op,
+            model,
+            node,
+            true,
+            &[k_axis, m_axis],
+        )?));
     };
     for axis in op.axes.iter_all_axes() {
         let one = TDim::one();
@@ -218,10 +224,22 @@ fn dequant_output(
     let name = &node.name;
     let mut patch = TypedModelPatch::new("Dequantizing einsum");
     let taps = patch.taps(model, &node.inputs)?;
-    let [a, b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
+    let [a, b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
         bail!("Expect exactly 9 inputs")
     };

+    if !patch.outlet_fact(a_scale)?.shape.volume().is_one() {
+        let q_axis_in_output = op.axes.axis((InOut::In(4), 0))?.outputs[0][0];
+        let output_rank = node.outputs[0].fact.rank();
+        for i in 1..(output_rank - q_axis_in_output) {
+            a_scale = patch.wire_node(
+                format!("{name}.a_scale_axis_fix_{i}"),
+                AxisOp::Add(i),
+                &[a_scale],
+            )?[0];
+        }
+    }
+
     let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut a0, "a0")?;
     let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut b0, "b0")?;
diff --git a/data/src/datum.rs b/data/src/datum.rs
index be1a275c20..432d48c7d4 100644
--- a/data/src/datum.rs
+++ b/data/src/datum.rs
@@ -304,6 +304,16 @@ impl DatumType {
             _ => *self,
         }
     }
+
+    pub fn quantize(&self, qparams: QParams) -> DatumType {
+        match self {
+            DatumType::I8 => DatumType::QI8(qparams),
+            DatumType::U8 => DatumType::QU8(qparams),
+            DatumType::I32 => DatumType::QI32(qparams),
+            _ => panic!("Can't quantize {self:?}"),
+        }
+    }
+
     #[inline(always)]
     pub fn zp_scale(&self) -> (i32, f32) {
         self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
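Note: the new DatumType::quantize composes with the existing zp_scale accessor shown above; a usage sketch (assuming the tract-data API exactly as in these hunks):

    let qdt = DatumType::I8.quantize(QParams::ZpScale { zero_point: 2, scale: 0.5 });
    assert!(qdt.is_quantized());
    assert_eq!(qdt.zp_scale(), (2, 0.5));
    assert_eq!(DatumType::F32.zp_scale(), (0, 1.0)); // unquantized types fall back to (0, 1.)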
diff --git a/data/src/tensor.rs b/data/src/tensor.rs
index 73da4636ce..1c486ef9e8 100644
--- a/data/src/tensor.rs
+++ b/data/src/tensor.rs
@@ -39,9 +39,10 @@ impl Approximation {
     fn atol_and_rtol(&self, dt: &DatumType) -> (f64, f64) {
         use Approximation::*;
         match (self, dt) {
+            (Exact, _) => (0.0, 0.0),
             (Close, DatumType::F16) => (1e-3, 1e-3),
             (Approximate, DatumType::F16) => (1e-3, 5e-3),
-            (Exact, _) => (0.0, 0.0),
+            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0.),
             (Close, _) => (1e-7, 1e-7),
             (Approximate, _) => (1e-4, 5e-4),
         }
diff --git a/test-rt/infra/src/lib.rs b/test-rt/infra/src/lib.rs
index 72018299bf..62f41f17ce 100644
--- a/test-rt/infra/src/lib.rs
+++ b/test-rt/infra/src/lib.rs
@@ -6,6 +6,7 @@ use dyn_clone::DynClone;
 use itertools::Itertools;
 use proptest::prelude::{any_with, Arbitrary};
 use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
+use tract_core::internal::Approximation;
 use tract_core::runtime::Runtime;
 use tract_core::tract_data::TractResult;

@@ -16,7 +17,11 @@ pub fn setup_test_logger() {
 pub type TestResult = anyhow::Result<()>;

 pub trait Test: 'static + Send + Sync + DynClone {
-    fn run(&self, runtime: &dyn Runtime) -> TestResult;
+    fn run(&self, id: &str, runtime: &dyn Runtime) -> TestResult {
+        self.run_with_approx(id, runtime, Approximation::Close)
+    }
+    fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation)
+        -> TestResult;
 }

 dyn_clone::clone_trait_object!(Test);
@@ -87,6 +92,32 @@ impl TestSuite {
         }
     }

+    pub fn get_sub(&self, id: &str) -> &TestSuite {
+        match self {
+            TestSuite::Node(n) => {
+                if let Some((head, tail)) = id.split_once("::") {
+                    n[head].get_sub(tail)
+                } else {
+                    &n[id]
+                }
+            }
+            TestSuite::Leaf(_, _) => panic!(),
+        }
+    }
+
+    pub fn get_sub_mut(&mut self, id: &str) -> &mut TestSuite {
+        match self {
+            TestSuite::Node(n) => {
+                if let Some((head, tail)) = id.split_once("::") {
+                    n.get_mut(head).unwrap().get_sub_mut(tail)
+                } else {
+                    n.get_mut(id).unwrap()
+                }
+            }
+            TestSuite::Leaf(_, _) => panic!(),
+        }
+    }
+
     fn ignore_rec(&mut self, prefix: &mut Vec<String>, ign: &dyn Fn(&[String]) -> bool) {
         match self {
             TestSuite::Node(n) => {
@@ -111,6 +142,7 @@ impl TestSuite {
         prefix: &str,
         id: &str,
         rs: &mut impl Write,
+        approx: &str,
     ) -> TractResult<()> {
         let full_id = [prefix, id].into_iter().filter(|s| s.len() > 0).join("::");
         match self {
@@ -120,7 +152,7 @@ impl TestSuite {
                     writeln!(rs, "use super::*;").unwrap();
                 }
                 for (id, test) in h.iter().sorted_by_key(|(k, _)| k.to_owned()) {
-                    test.dump(test_suite, runtime, &full_id, id, rs)?;
+                    test.dump(test_suite, runtime, &full_id, id, rs, approx)?;
                 }
                 if id.len() > 0 {
                     writeln!(rs, "}}").unwrap();
@@ -133,21 +165,25 @@ impl TestSuite {
                     writeln!(rs, "#[ignore]").unwrap();
                 }
                 writeln!(rs, "fn {id}() -> TractResult<()> {{",).unwrap();
-                writeln!(rs, "    {test_suite}.get({full_id:?}).run({runtime})",).unwrap();
+                writeln!(
+                    rs,
+                    "    {test_suite}.get({full_id:?}).run_with_approx({full_id:?}, {runtime}, {approx})",
+                )
+                .unwrap();
                 writeln!(rs, "}}").unwrap();
             }
         }
         Ok(())
     }

-    pub fn test_runtime(&self, name: &str, test_suite: &str, runtime: &str) {
+    pub fn test_runtime(&self, name: &str, test_suite: &str, runtime: &str, approx: &str) {
         let out_dir = std::env::var("OUT_DIR").unwrap();
         let out_dir = std::path::PathBuf::from(out_dir);
         let test_dir = out_dir.join("tests");
         std::fs::create_dir_all(&test_dir).unwrap();
         let test_file = test_dir.join(name).with_extension("rs");
         let mut rs = std::fs::File::create(test_file).unwrap();
-        self.dump(test_suite, runtime, "", "", &mut rs).unwrap();
+        self.dump(test_suite, runtime, "", "", &mut rs, approx).unwrap();
     }
 }

@@ -160,13 +196,13 @@ impl<A: Arbitrary + Test> Test for ProptestWrapper<A>
 where
     A::Parameters: Clone + Send + Sync,
 {
-    fn run(&self, runtime: &dyn Runtime) -> TestResult {
+    fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation) -> TestResult {
         let mut runner = TestRunner::new(Config {
             failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
             ..Config::default()
         });
         runner.run(&any_with::<A>(self.0.clone()), |v| {
-            v.run(runtime).unwrap();
+            v.run_with_approx(id, runtime, approx).unwrap();
             Ok(())
         })?;
         Ok(())
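Note: with approx threaded through dump(), generated tests call run_with_approx instead of run. Reconstructed from the writeln! templates above, a leaf emitted with approx = "Approximation::Approximate" looks roughly like this (identifiers illustrative):

    #[test]
    fn trivial_0() -> TractResult<()> {
        suite_conv::suite().unwrap().get("q::trivial_0").run_with_approx(
            "q::trivial_0",
            default(),
            Approximation::Approximate,
        )
    }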
diff --git a/test-rt/suite-conv/src/conv_f32.rs b/test-rt/suite-conv/src/conv_f32.rs
index 5ad2b9e7a6..ee5a4b564d 100644
--- a/test-rt/suite-conv/src/conv_f32.rs
+++ b/test-rt/suite-conv/src/conv_f32.rs
@@ -242,12 +242,18 @@ impl Arbitrary for ConvProblem {
 }

 impl Test for ConvProblem {
-    fn run(&self, runtime: &dyn Runtime) -> TestResult {
+    fn run_with_approx(
+        &self,
+        id: &str,
+        runtime: &dyn Runtime,
+        approx: Approximation,
+    ) -> TestResult {
         let reference = self.reference().into_tensor();
-        let mut output =
-            runtime.prepare(self.tract()?)?.run(tvec![self.data.clone().into_tvalue()])?;
+        let mut model = self.tract()?;
+        model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
+        let mut output = runtime.prepare(model)?.run(tvec![self.data.clone().into_tvalue()])?;
         let output = output.remove(0).into_tensor();
-        output.close_enough(&reference, true)
+        output.close_enough(&reference, approx)
     }
 }
diff --git a/test-rt/suite-conv/src/conv_q.rs b/test-rt/suite-conv/src/conv_q.rs
index e22f0d6cd8..cf1afb97ff 100644
--- a/test-rt/suite-conv/src/conv_q.rs
+++ b/test-rt/suite-conv/src/conv_q.rs
@@ -8,22 +8,58 @@ use tract_core::ops::cnn::{ConvUnary, KernelFormat, PaddingSpec, PoolSpec};
 use tract_core::ops::math::round_ties_to_even;
 use tract_core::ops::nn::DataFormat::*;
 use tract_core::ops::nn::DataShape;
+use tract_core::tract_data::itertools::Itertools;
 use tract_itertools::izip;
 use tract_ndarray::*;

+use crate::conv_f32::ConvProblemParams;
+
 pub fn qtensor(shape: Vec<usize>) -> BoxedStrategy<ArrayD<i8>> {
     let len = shape.iter().product::<usize>();
     vec(any::<i8>(), len..=len)
         .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap())
         .boxed()
 }

+/* https://www.tensorflow.org/lite/performance/quantization_spec
+CONV_2D
+Input 0:
+  data_type  : int8
+  range      : [-128, 127]
+  granularity: per-tensor
+Input 1 (Weight):
+  data_type  : int8
+  range      : [-127, 127]
+  granularity: per-axis (dim = 0)
+  restriction: zero_point = 0
+Input 2 (Bias):
+  data_type  : int32
+  range      : [int32_min, int32_max]
+  granularity: per-axis
+  restriction: (scale, zero_point) = (input0_scale * input1_scale[...], 0)
+Output 0:
+  data_type  : int8
+  range      : [-128, 127]
+  granularity: per-tensor
+*/

-pub fn q_params() -> BoxedStrategy<[Tensor; 6]> {
-    (-10i32..10, -10i32..10, -10i32..10, -3..3i32, -3..3i32, -3..3i32)
+pub fn q_params(params: &QConvProblemParams, co: usize) -> BoxedStrategy<[Tensor; 6]> {
+    let a0 = if params.no_kernel_zero_point { Just(0i32).boxed() } else { (-10..10i32).boxed() };
+    (
+        a0,
+        -10i32..10,
+        -10i32..10,
+        prop_oneof![
+            (Just(false), (-3..3i32).prop_map(|x| vec!(x)).boxed()),
+            (Just(true), vec(-3..3i32, co..=co).boxed())
+        ],
+        -3..3i32,
+        -3..3i32,
+    )
         .prop_map(|(a0, b0, c0, a_scale, b_scale, c_scale)| {
+            let a_scale_values = a_scale.1.iter().map(|x| 2f32.powi(*x)).collect_vec();
             [
                 tensor0(a0),
-                tensor0(2f32.powi(a_scale)),
+                if a_scale.0 { tensor1(&a_scale_values) } else { tensor0(a_scale_values[0]) },
                 tensor0(b0),
                 tensor0(2f32.powi(b_scale)),
                 tensor0(c0),
@@ -33,8 +69,14 @@ pub fn q_params() -> BoxedStrategy<[Tensor; 6]> {
         .boxed()
 }

+#[derive(Debug, Clone, Default)]
+pub struct QConvProblemParams {
+    pub conv: ConvProblemParams,
+    pub no_kernel_zero_point: bool,
+}
+
 #[derive(Debug, Clone)]
-struct QConvProblem {
+pub struct QConvProblem {
     shape_in: DataShape,
     kernel_format: KernelFormat,
     co: usize,
@@ -50,7 +92,7 @@ impl QConvProblem {
         &self.kernel.shape()[self.kernel_format.h_axis()..][..self.shape_in.hw_rank()]
     }

-    fn reference(&self) -> ArrayD<i8> {
+    fn reference(&self) -> Tensor {
         assert_eq!(self.data.shape(), &*self.shape_in.shape);
         let n = *self.shape_in.n().unwrap_or(&1);
         let ci_per_g = self.shape_in.c() / self.group;
@@ -58,9 +100,8 @@ impl QConvProblem {
         let a0 = self.qp[0].cast_to_scalar::<i32>().unwrap();
         let b0 = self.qp[2].cast_to_scalar::<i32>().unwrap();
         let c0 = self.qp[4].cast_to_scalar::<i32>().unwrap();
-        let scale = self.qp[5].cast_to_scalar::<f32>().unwrap()
-            / self.qp[1].cast_to_scalar::<f32>().unwrap()
-            / self.qp[3].cast_to_scalar::<f32>().unwrap();
+        let b_scale = self.qp[3].cast_to_scalar::<f32>().unwrap();
+        let c_scale = self.qp[5].cast_to_scalar::<f32>().unwrap();
         let shape_out: TVec<usize> = izip!(self.shape_in.hw_dims(), self.geo_ker())
             .map(|(i, k)| (*i + 1).saturating_sub(*k))
             .collect();
         let shape_out = self
             .shape_in
             .fmt
             .from_n_c_hw(self.shape_in.n().cloned().unwrap_or(1), co_per_g * self.group, shape_out)
             .unwrap();
+        // a is the kernel, it can be quantized per O axis
+        let a_scale = if self.qp[1].len() == 1 {
+            vec![self.qp[1].cast_to_scalar::<f32>().unwrap(); *shape_out.c()]
+        } else {
+            self.qp[1].as_slice::<f32>().unwrap().into()
+        };
         let mut temp = ArrayD::<i32>::zeros(&*shape_out.shape);
         for n in 0..n {
             for g in 0..self.group {
@@ -118,21 +165,45 @@ impl QConvProblem {
             shape[shape_out.c_axis()] = bias.len();
             temp += &bias.clone().into_shape(shape).unwrap();
         }
-        temp.mapv(|i| {
-            (round_ties_to_even(i as f32 / scale) as i32 + c0)
-                .max(std::i8::MIN as i32)
-                .min(std::i8::MAX as i32) as i8
-        })
+        temp.axis_iter_mut(Axis(shape_out.c_axis())).zip(a_scale).for_each(
+            |(mut view, a_scale)| {
+                view.mapv_inplace(|i| {
+                    (round_ties_to_even(i as f32 / c_scale * a_scale * b_scale) as i32 + c0)
+                        .max(std::i8::MIN as i32)
+                        .min(std::i8::MAX as i32)
+                })
+            },
+        );
+        temp.into_tensor()
+            .cast_to_dt(
+                i8::datum_type().quantize(QParams::ZpScale { zero_point: c0, scale: c_scale }),
+            )
+            .unwrap()
+            .into_owned()
     }

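Note: hand-checking the requantization expression above on the "rounding_0" case added later in this file (temp = (4 - 2) * (-5 - 0) = -10, bias -125, so i = -135; a_scale = 0.5, b_scale = 2.0, c_scale = 2.0, c0 = 0):

    // i as f32 / c_scale * a_scale * b_scale = -135.0 / 2.0 * 0.5 * 2.0 = -67.5
    // round_ties_to_even(-67.5) = -68   (ties go to the even neighbour)
    // -68 + c0 = -68, within [-128, 127], so clamping leaves it untouched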
     fn tract(&self) -> TractResult<TypedModel> {
-        assert_eq!(self.data.shape(), &*self.shape_in.shape);
+        assert!(self.data.shape() == &*self.shape_in.shape);
         let mut model = TypedModel::default();
-        let wire = model.add_source("input", i8::fact(&self.shape_in.shape))?;
+        let kdt = DatumType::QI8(QParams::ZpScale {
+            zero_point: self.qp[0].cast_to_scalar()?,
+            scale: *self.qp[1].to_scalar()?,
+        });
+        let idt = DatumType::QI8(QParams::ZpScale {
+            zero_point: self.qp[2].cast_to_scalar()?,
+            scale: *self.qp[3].to_scalar()?,
+        });
+        let cdt = DatumType::QI8(QParams::ZpScale {
+            zero_point: self.qp[4].cast_to_scalar()?,
+            scale: *self.qp[5].to_scalar()?,
+        });
+        let wire = model.add_source("input", idt.fact(&self.shape_in.shape))?;
         let mut inputs = tvec!(wire);
         for (ix, qp) in self.qp.iter().enumerate() {
             inputs.push(model.add_const(format!("qp.{ix}"), qp.clone())?);
         }
+        let mut kernel = self.kernel.clone().into_tensor();
+        unsafe { kernel.set_datum_type(kdt) };
         let op = ConvUnary::new(
             PoolSpec::new(
                 self.shape_in.fmt,
@@ -143,10 +214,10 @@ impl QConvProblem {
                 Some(self.co),
             ),
             self.kernel_format,
-            self.kernel.clone().into_arc_tensor(),
+            kernel.into_arc_tensor(),
             self.group,
             self.bias.clone().map(|a| a.into_arc_tensor()),
-            Some(i8::datum_type()),
+            Some(cdt),
         );
         let wire = model.wire_node("conv", op, &inputs)?[0];
         model.set_output_outlets(&[wire])?;
@@ -155,71 +226,83 @@ impl QConvProblem {
 }

 impl Test for QConvProblem {
-    fn run(&self, runtime: &dyn Runtime) -> TractResult<()> {
-        let model = runtime.prepare(self.tract()?)?;
-        let output = model.run(tvec!(self.data.clone().into_tensor().into_tvalue()))?.remove(0);
-        output.close_enough(&self.reference().into_tensor(), Approximation::Exact)
+    fn run_with_approx(
+        &self,
+        id: &str,
+        runtime: &dyn Runtime,
+        approx: Approximation,
+    ) -> infra::TestResult {
+        let reference = self.reference();
+        let mut model = self.tract()?;
+        model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
+        let model = runtime.prepare(model)?;
+        let idt = DatumType::QI8(QParams::ZpScale {
+            zero_point: self.qp[2].cast_to_scalar()?,
+            scale: *self.qp[3].to_scalar()?,
+        });
+        let data = self.data.clone().into_tensor().cast_to_dt(idt)?.into_owned().into_tvalue();
+        let output = model.run(tvec!(data))?.remove(0);
+        eprintln!("reference: {reference:?}\noutput   : {output:?}");
+        output.close_enough(&reference, approx)
     }
 }

 impl Arbitrary for QConvProblem {
-    type Parameters = ();
+    type Parameters = QConvProblemParams;
     type Strategy = BoxedStrategy<QConvProblem>;
-    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
+        let geo_rank = params.conv.geo_rank.clone().unwrap_or(1..4);
         (
             crate::data_format(),
             crate::kernel_format(),
             1usize..=10,
             1usize..=8,
             1usize..=8,
-            1usize..=3,
-            (1usize..=3).prop_flat_map(crate::shapes),
-            q_params(),
+            1usize..=(if params.conv.no_group { 1 } else { 3 }),
+            geo_rank.prop_flat_map(crate::shapes),
         )
-            .prop_flat_map(|(df, kf, n, mut ci0, co0, group, (mut ker_shape, data_shape), qp)| {
-                // FIXME in HWIO order, only regular and depthwise are supported
-                if kf == KernelFormat::HWIO && group > 1 {
-                    ci0 = 1;
-                }
-                let shape_in = df.from_n_c_hw(n, ci0 * group, data_shape).unwrap();
-                let data_in = qtensor(shape_in.shape.iter().cloned().collect());
-                match kf {
-                    KernelFormat::HWIO => {
-                        ker_shape.push(ci0 * group);
-                        ker_shape.push(co0)
-                    }
-                    KernelFormat::OIHW => {
-                        ker_shape.insert(0, ci0);
-                        ker_shape.insert(0, co0 * group)
-                    }
-                    KernelFormat::OHWI => {
-                        ker_shape.insert(0, co0);
-                        ker_shape.push(ci0 * group)
-                    }
-                };
-                let kernel = qtensor(ker_shape);
-                let bias = proptest::option::of(
-                    qtensor(vec![co0 * group]).prop_map(|a| a.mapv(|v| v as i32)),
-                );
-                (Just((kf, shape_in, co0 * group, group, qp)), data_in, kernel, bias)
-                // FIXME
-            })
-            .prop_map(|((kernel_format, shape_in, co, group, qp), data, kernel, bias)| {
+            .prop_flat_map(
+                move |(df, kf, n, mut ci0, mut co0, group, (mut ker_shape, data_shape))| {
+                    // FIXME in HWIO order, only regular and depthwise are supported
+                    if params.conv.no_arbitrary_grouping && group > 1 {
+                        ci0 = 1;
+                        co0 = 1;
+                    }
+                    if kf == KernelFormat::HWIO && group > 1 {
+                        ci0 = 1;
+                    }
+                    let qp = q_params(&params, co0 * group);
+                    let shape_in = df.from_n_c_hw(n, ci0 * group, data_shape).unwrap();
+                    let data_in = qtensor(shape_in.shape.iter().cloned().collect());
+                    match kf {
+                        KernelFormat::HWIO => {
+                            ker_shape.push(ci0 * group);
+                            ker_shape.push(co0)
+                        }
+                        KernelFormat::OIHW => {
+                            ker_shape.insert(0, ci0);
+                            ker_shape.insert(0, co0 * group)
+                        }
+                        KernelFormat::OHWI => {
+                            ker_shape.insert(0, co0);
+                            ker_shape.push(ci0 * group)
+                        }
+                    };
+                    let kernel = qtensor(ker_shape);
+                    let bias = proptest::option::of(
+                        qtensor(vec![co0 * group]).prop_map(|a| a.mapv(|v| v as i32)),
+                    );
+                    (Just((kf, shape_in, co0 * group, group)), data_in, kernel, bias, qp)
+                    // FIXME
+                },
+            )
+            .prop_map(|((kernel_format, shape_in, co, group), data, kernel, bias, qp)| {
                 QConvProblem { shape_in, co, kernel_format, group, data, kernel, bias, qp }
             })
             .boxed()
     }
 }

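Note: QConvProblem::Parameters is now QConvProblemParams, so a tuned strategy is obtained through any_with, which is what the infra ProptestWrapper does under the hood; e.g.:

    let strategy = any_with::<QConvProblem>(QConvProblemParams {
        no_kernel_zero_point: true,
        ..QConvProblemParams::default()
    });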
-/*
-proptest::proptest! {
-    #[test]
-    fn prop(pb in any::<QConvProblem>()) {
-        pb.check().unwrap()
-    }
-}
-*/
-
 fn qp_noop_i8() -> [Tensor; 6] {
     [tensor0(0i8), tensor0(1f32), tensor0(0i8), tensor0(1f32), tensor0(0i8), tensor0(1f32)]
 }
@@ -227,6 +310,8 @@ fn qp_noop_i8() -> [Tensor; 6] {
 pub fn suite() -> TractResult<TestSuite> {
     let mut suite = TestSuite::default();

+    suite.add_arbitrary::<QConvProblem>("proptest", QConvProblemParams::default());
+
     suite.add(
         "trivial_0",
         QConvProblem {
@@ -279,6 +364,53 @@ pub fn suite() -> TractResult<TestSuite> {
             qp: qp_noop_i8(),
         },
     );
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor1(&[1f32, 0.5]);
+    qp[2] = tensor0(-2i8);
+    suite.add(
+        "weird_4",
+        QConvProblem {
+            shape_in: CHW.from_n_c_hw(1, 1, [1]).unwrap(),
+            kernel_format: OIHW,
+            co: 2,
+            group: 1,
+            data: arr2(&[[0i8]]).into_dyn(),
+            kernel: arr3(&[[[0i8]], [[7]]]).into_dyn(),
+            bias: None,
+            qp,
+        },
+    );
+
+    suite.add(
+        "a0_0",
+        QConvProblem {
+            shape_in: HWC.from_n_c_hw(1, 1, [1]).unwrap(),
+            co: 1,
+            kernel_format: OIHW,
+            group: 1,
+            data: arr2(&[[1]]).into_dyn(),
+            kernel: arr3(&[[[0]]]).into_dyn(),
+            bias: None,
+            qp: qp_noop_i8(),
+        },
+    );
+
+    let mut qp = qp_noop_i8();
+    qp[0] = tensor0(1i32);
+    suite.add(
+        "kernel_zp",
+        QConvProblem {
+            shape_in: CHW.from_n_c_hw(1, 1, [1]).unwrap(),
+            kernel_format: OIHW,
+            co: 1,
+            group: 1,
+            data: arr2(&[[1i8]]).into_dyn(),
+            kernel: arr3(&[[[0i8]]]).into_dyn(),
+            bias: None,
+            qp,
+        },
+    );
+
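Note: kernel_zp above pins down the case the new no_kernel_zero_point knob exists for: the TFLite spec quoted at the top of this file requires weight zero_point = 0, so a kernel zp of 1 is representable in tract but not in tflite (the test is ignored in the tflite suite further down). Its expected value, by hand:

    // temp = (data - b0) * (kernel - a0) = (1 - 0) * (0 - 1) = -1
    // all scales are 1.0 and c0 = 0, so the quantized output is just -1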
     let mut qp = qp_noop_i8();
     qp[2] = tensor0(1i32);
     suite.add(
@@ -336,19 +468,6 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
-    suite.add(
-        "a0_0",
-        QConvProblem {
-            shape_in: HWC.from_n_c_hw(1, 1, [1]).unwrap(),
-            co: 1,
-            kernel_format: OIHW,
-            group: 1,
-            data: arr2(&[[1]]).into_dyn(),
-            kernel: arr3(&[[[0]]]).into_dyn(),
-            bias: None,
-            qp: qp_noop_i8(),
-        },
-    );
     let mut qp = qp_noop_i8();
     qp[5] = tensor0(9.274534f32);
     suite.add(
@@ -468,6 +587,26 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
+
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor0(0.5f32);
+    qp[2] = tensor0(2i32);
+    qp[3] = tensor0(2f32);
+    qp[5] = tensor0(2f32);
+    suite.add(
+        "rounding_0",
+        QConvProblem {
+            shape_in: CHW.from_n_c_hw(1, 1, [1]).unwrap(),
+            co: 1,
+            kernel_format: OIHW,
+            group: 1,
+            data: arr2(&[[4i8]]).into_dyn(),
+            kernel: arr3(&[[[-5]]]).into_dyn(),
+            bias: Some(arr1(&[-125i32]).into_dyn()),
+            qp,
+        },
+    );
+
     let mut qp = qp_noop_i8();
     qp[5] = tensor0(1.3759452f32);
     suite.add(
@@ -483,9 +622,7 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
-    let qp = qp_noop_i8();
-    let data = ArrayD::zeros(vec![1, 1, 1, 1]);
-    let kernel = ArrayD::zeros(vec![2, 1, 1, 1]);
+
     suite.add(
         "bias_1",
         QConvProblem {
@@ -493,12 +630,13 @@ pub fn suite() -> TractResult<TestSuite> {
             co: 2,
             kernel_format: OIHW,
             group: 1,
-            data,
-            kernel,
+            data: ArrayD::zeros(vec![1, 1, 1, 1]),
+            kernel: ArrayD::zeros(vec![2, 1, 1, 1]),
             bias: Some(tract_ndarray::arr1(&[1, 2]).into_dyn()),
-            qp,
+            qp: qp_noop_i8(),
         },
     );
+
     let qp = qp_noop_i8();
     let data = ArrayD::zeros(vec![1, 1]);
     let kernel = ArrayD::zeros(vec![2, 1, 1]);
@@ -515,6 +653,7 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
+
     let mut qp = qp_noop_i8();
     qp[2] = tensor0(-1);
     let data = ArrayD::zeros(vec![2, 1]);
@@ -533,6 +672,35 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
+
+    suite.add(
+        "bias_4",
+        QConvProblem {
+            shape_in: NHWC.from_n_c_hw(1, 1, [1, 1]).unwrap(),
+            co: 2,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1, 1, 1]),
+            kernel: ArrayD::zeros(vec![2, 1, 1, 1]),
+            bias: Some(tract_ndarray::arr1(&[0, 1]).into_dyn()),
+            qp: qp_noop_i8(),
+        },
+    );
+
+    suite.add(
+        "bias_5",
+        QConvProblem {
+            shape_in: NHWC.from_n_c_hw(1, 1, [1, 1]).unwrap(),
+            co: 1,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1, 1, 1]),
+            kernel: ArrayD::zeros(vec![1, 1, 1, 1]),
+            bias: Some(tract_ndarray::arr1(&[1]).into_dyn()),
+            qp: qp_noop_i8(),
+        },
+    );
+
     let qp = qp_noop_i8();
     let data = ArrayD::zeros(vec![1, 1]);
     let kernel = ArrayD::zeros(vec![2, 1, 1]);
@@ -597,5 +765,65 @@ pub fn suite() -> TractResult<TestSuite> {
             qp,
         },
     );
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor1(&[1f32, 1f32]);
+    suite.add(
+        "tflite_per_axis_0",
+        QConvProblem {
+            shape_in: CHW.from_n_c_hw(1, 1, [1]).unwrap(),
+            co: 2,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1]),
+            kernel: ArrayD::zeros(vec![2, 1, 1]),
+            bias: None,
+            qp,
+        },
+    );
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor1(&[1f32, 1f32]);
+    suite.add(
+        "tflite_per_axis_1",
+        QConvProblem {
+            shape_in: CHW.from_n_c_hw(1, 1, [1, 2]).unwrap(),
+            co: 2,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1, 2]),
+            kernel: ArrayD::zeros(vec![2, 1, 1, 2]),
+            bias: None,
+            qp,
+        },
+    );
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor1(&[1f32, 1f32]);
+    suite.add(
+        "tflite_per_axis_nchw_0",
+        QConvProblem {
+            shape_in: NCHW.from_n_c_hw(1, 1, [1]).unwrap(),
+            co: 2,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1, 1]),
+            kernel: ArrayD::zeros(vec![2, 1, 1]),
+            bias: None,
+            qp,
+        },
+    );
+    let mut qp = qp_noop_i8();
+    qp[1] = tensor1(&[1f32, 1f32]);
+    suite.add(
+        "tflite_per_axis_nchw_1",
+        QConvProblem {
+            shape_in: NCHW.from_n_c_hw(1, 1, [2]).unwrap(),
+            co: 2,
+            kernel_format: OIHW,
+            group: 1,
+            data: ArrayD::zeros(vec![1, 1, 2]),
+            kernel: ArrayD::zeros(vec![2, 1, 2]),
+            bias: None,
+            qp,
+        },
+    );
     Ok(suite)
 }
diff --git a/test-rt/suite-conv/src/deconv.rs b/test-rt/suite-conv/src/deconv.rs
index 6eb7668902..2d4cf96bde 100644
--- a/test-rt/suite-conv/src/deconv.rs
+++ b/test-rt/suite-conv/src/deconv.rs
@@ -270,12 +270,14 @@ impl DeconvProblem {
 }

 impl Test for DeconvProblem {
-    fn run(&self, runtime: &dyn Runtime) -> TestResult {
+    fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation) -> TestResult {
         let reference = self.reference().into_tensor();
+        let mut model = self.tract()?;
+        model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
         let mut output =
-            runtime.prepare(self.tract()?)?.run(tvec![self.input.clone().into_tvalue()])?;
+            runtime.prepare(model)?.run(tvec![self.input.clone().into_tvalue()])?;
         let output = output.remove(0).into_tensor();
-        output.close_enough(&reference, true)
+        output.close_enough(&reference, approx)
     }
 }
diff --git a/test-rt/suite-onnx/src/lib.rs b/test-rt/suite-onnx/src/lib.rs
index e3de3fb7ae..808cacf401 100644
--- a/test-rt/suite-onnx/src/lib.rs
+++ b/test-rt/suite-onnx/src/lib.rs
@@ -27,7 +27,7 @@ struct OnnxTestCase {
 }

 impl Test for OnnxTestCase {
-    fn run(&self, runtime: &dyn Runtime) -> TractResult<()> {
+    fn run_with_approx(&self, _id: &str, runtime: &dyn Runtime, approx: Approximation) -> TractResult<()> {
         setup_test_logger();
         let model_file = self.path.join("model.onnx");
         info!("Loading {:?}", model_file);
@@ -78,7 +78,7 @@ impl Test for OnnxTestCase {
             let model = model.into_typed()?.into_decluttered()?;
             info!("Test model (mode: {}) {:#?}", runtime.name(), self.path);
             let runnable = runtime.prepare(model)?;
-            run_model(&*runnable, inputs, &data_path);
+            run_model(&*runnable, inputs, &data_path, approx);
             info!("Test model (mode: {}) {:#?} OK.", runtime.name(), self.path);
         }
     }
@@ -139,10 +139,7 @@ pub fn ensure_onnx_git_checkout() {
     for (v, _) in versions() {
         let wanted = dir().join(format!("onnx-{}", v.replace('.', "_")));
         if !wanted.join("onnx/backend/test/data").exists() {
-            let df = std::process::Command::new("df")
-                .arg("-h")
-                .output()
-                .unwrap();
+            let df = std::process::Command::new("df").arg("-h").output().unwrap();
             dbg!(df);
             let tmp = wanted.with_extension("tmp");
             let _ = std::fs::remove_dir_all(&wanted);
@@ -263,7 +260,12 @@ pub fn load_half_dataset(prefix: &str, path: &std::path::Path) -> TVec<Tensor> {
     vec
 }

-fn run_model(model: &dyn Runnable, inputs: TVec<Tensor>, data_path: &std::path::Path) {
+fn run_model(
+    model: &dyn Runnable,
+    inputs: TVec<Tensor>,
+    data_path: &std::path::Path,
+    approx: Approximation,
+) {
     let expected = load_half_dataset("output", data_path);
     trace!("Loaded output asserts: {:?}", expected);
     let inputs = inputs.into_iter().map(|t| t.into_tvalue()).collect();
@@ -279,9 +281,9 @@ fn run_model(
     for (ix, (a, b)) in computed.iter().zip(expected.iter()).enumerate() {
         // println!("computed: {:?}", computed[ix].dump(true));
         // println!("expected: {:?}", expected[ix].dump(true));
-        if let Err(e) = a.close_enough(b, true) {
+        if let Err(e) = a.close_enough(b, approx) {
             panic!(
-                "For {:?}, different result for output #{}:\ngot:\n{:?}\nexpected:\n{:?}\n{}",
+                "For {:?}, different ({approx:?}) result for output #{}:\ngot:\n{:?}\nexpected:\n{:?}\n{}",
                 data_path,
                 ix,
                 a.cast_to::<f32>().unwrap().to_array_view::<f32>().unwrap(),
diff --git a/test-rt/test-conv-core/build.rs b/test-rt/test-conv-core/build.rs
index b63fffc428..6d893e02d4 100644
--- a/test-rt/test-conv-core/build.rs
+++ b/test-rt/test-conv-core/build.rs
@@ -1,5 +1,15 @@
 fn main() {
     let suite = suite_conv::suite().unwrap();
-    suite.test_runtime("default", "suite_conv::suite().unwrap()", "default()");
-    suite.test_runtime("unoptimized", "suite_conv::suite().unwrap()", "unoptimized()");
+    suite.test_runtime(
+        "default",
+        "suite_conv::suite().unwrap()",
+        "default()",
+        "Approximation::Approximate",
+    );
+    suite.test_runtime(
+        "unoptimized",
+        "suite_conv::suite().unwrap()",
+        "unoptimized()",
+        "Approximation::Approximate",
+    );
 }
diff --git a/test-rt/test-onnx-core/build.rs b/test-rt/test-onnx-core/build.rs
index 5661fbecfc..ad47c9b621 100644
--- a/test-rt/test-onnx-core/build.rs
+++ b/test-rt/test-onnx-core/build.rs
@@ -1,5 +1,10 @@
 fn main() {
     let suite = suite_onnx::suite();
-    suite.test_runtime("default", "suite_onnx::suite()", "default()");
-    suite.test_runtime("unoptimized", "suite_onnx::suite()", "unoptimized()");
+    suite.test_runtime("default", "suite_onnx::suite()", "default()", "Approximation::Approximate");
+    suite.test_runtime(
+        "unoptimized",
+        "suite_onnx::suite()",
+        "unoptimized()",
+        "Approximation::Approximate",
+    );
 }
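Note: these build scripts emit the generated test files under $OUT_DIR/tests/<name>.rs. The patch does not show the consuming side; presumably each test-rt crate pulls its file in with something like:

    include!(concat!(env!("OUT_DIR"), "/tests/default.rs"));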
diff --git a/test-rt/test-onnx-nnef-cycle/build.rs b/test-rt/test-onnx-nnef-cycle/build.rs
index 91e1ec30db..fb5230e1d6 100644
--- a/test-rt/test-onnx-nnef-cycle/build.rs
+++ b/test-rt/test-onnx-nnef-cycle/build.rs
@@ -1,7 +1,7 @@
 fn main() {
     let mut suite = suite_onnx::suite().clone();
     suite.ignore(&ignore);
-    suite.test_runtime("nnef_cycle", "suite_onnx::suite()", "nnef_cycle()");
+    suite.test_runtime("nnef_cycle", "suite_onnx::suite()", "nnef_cycle()", "Approximation::Approximate");
 }

 fn ignore(t: &[String]) -> bool {
diff --git a/test-rt/test-tflite/build.rs b/test-rt/test-tflite/build.rs
index 21c3721318..09d4238f2d 100644
--- a/test-rt/test-tflite/build.rs
+++ b/test-rt/test-tflite/build.rs
@@ -2,6 +2,6 @@
 mod suite;

 fn main() {
-    suite::suite().test_runtime("tests", "suite::suite()", "runtime()");
+    suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::Approximate");
 }
diff --git a/test-rt/test-tflite/src/tflite_runtime.rs b/test-rt/test-tflite/src/tflite_runtime.rs
index 33a8a438fa..40398116ee 100644
--- a/test-rt/test-tflite/src/tflite_runtime.rs
+++ b/test-rt/test-tflite/src/tflite_runtime.rs
@@ -14,12 +14,16 @@ impl Runtime for TfliteRuntime {
     fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
         let mut buffer = vec![];
         self.0.write(&model, &mut buffer)?;
-        // std::fs::write("foo.tflite", &buffer).unwrap();
-        Ok(Box::new(TfliteRunnable(buffer)))
+        let output_dt = model
+            .output_outlets()?
+            .iter()
+            .map(|oo| model.outlet_fact(*oo).unwrap().datum_type)
+            .collect();
+        Ok(Box::new(TfliteRunnable(buffer, output_dt)))
     }
 }

-struct TfliteRunnable(Vec<u8>);
+struct TfliteRunnable(Vec<u8>, TVec<DatumType>);

 impl Runnable for TfliteRunnable {
     fn spawn(&self) -> TractResult<Box<dyn State>> {
@@ -28,11 +32,11 @@ impl Runnable for TfliteRunnable {
         let builder = InterpreterBuilder::new(fb, resolver)?;
         let mut interpreter = builder.build()?;
         interpreter.allocate_tensors()?;
-        Ok(Box::new(TfliteState(interpreter)))
+        Ok(Box::new(TfliteState(interpreter, self.1.clone())))
     }
 }

-struct TfliteState(Interpreter<'static, BuiltinOpResolver>);
+struct TfliteState(Interpreter<'static, BuiltinOpResolver>, TVec<DatumType>);

 impl State for TfliteState {
     fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
@@ -40,7 +44,6 @@ impl State for TfliteState {
         for (ix, input) in inputs.iter().enumerate() {
             let input_ix = self.0.inputs()[ix];
             let input_tensor = self.0.tensor_info(input_ix).unwrap();
-            assert_eq!(input_tensor.element_kind as u32, 1); // 1 is f32
             assert_eq!(input_tensor.dims, input.shape());
             self.0.tensor_buffer_mut(input_ix).unwrap().copy_from_slice(unsafe { input.as_bytes() })
         }
@@ -49,10 +52,15 @@ impl State for TfliteState {
         for ix in 0..self.0.outputs().len() {
             let output_ix = self.0.outputs()[ix];
             let output_tensor = self.0.tensor_info(output_ix).unwrap();
-            assert_eq!(output_tensor.element_kind as u32, 1);
+            let dt = match output_tensor.element_kind as u32 {
+                1 => f32::datum_type(),
+                9 => self.1[ix].clone(), // impossible to retrieve QP from this TFL binding
+                _ => bail!("unknown type"),
+            };
+            dbg!(self.0.tensor_buffer(output_ix));
             let tensor = unsafe {
                 Tensor::from_raw_dt(
-                    f32::datum_type(),
+                    dt,
                     &output_tensor.dims,
                     self.0.tensor_buffer(output_ix).unwrap(),
                 )?
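Note: the element_kind magic numbers above follow the TFLite C API's TfLiteType enum, of which only float32 and int8 are handled here; for orientation (values from the C API):

    // kTfLiteNoType = 0, kTfLiteFloat32 = 1, kTfLiteInt32 = 2, kTfLiteUInt8 = 3,
    // kTfLiteInt64 = 4, kTfLiteString = 5, kTfLiteBool = 6, kTfLiteInt16 = 7,
    // kTfLiteComplex64 = 8, kTfLiteInt8 = 9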
diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs
index 4df8eba4f8..99e2779555 100644
--- a/test-rt/test-tflite/suite.rs
+++ b/test-rt/test-tflite/suite.rs
@@ -1,13 +1,18 @@
 use suite_conv::conv_f32::{ConvProblem, ConvProblemParams};
+use suite_conv::conv_q::{QConvProblem, QConvProblemParams};

+#[allow(clippy::needless_update)]
 pub fn suite() -> infra::TestSuite {
     let mut onnx = suite_onnx::suite().clone();
     onnx.ignore(&ignore_onnx);
     let mut conv = suite_conv::suite().unwrap().clone();
     conv.ignore(&ignore_conv);
-    conv.add_arbitrary::<ConvProblem>(
+    let cv =
+        ConvProblemParams { no_group: true, geo_rank: Some(1..3), ..ConvProblemParams::default() };
+    conv.get_sub_mut("f32").add_arbitrary::<ConvProblem>("proptest", cv.clone());
+    conv.get_sub_mut("q").add_arbitrary::<QConvProblem>(
         "proptest",
-        ConvProblemParams { no_group: true, geo_rank: Some(1..3), ..ConvProblemParams::default() },
+        QConvProblemParams { conv: cv, no_kernel_zero_point: true, ..QConvProblemParams::default() },
     );
     infra::TestSuite::default().with("onnx", onnx).with("conv", conv)
 }
@@ -28,8 +33,7 @@ fn ignore_onnx(t: &[String]) -> bool {

 fn ignore_conv(t: &[String]) -> bool {
     let [section, unit] = t else { return false };
-    section == "q" || section == "deconv"
-        || unit == "proptest"
+    section == "deconv"
         // grouping and depthwise
         || unit.starts_with("group")
         // conv 3D
         || unit == "lazy_im2col_big_2"
         || unit == "batch_3d"
         || unit == "bias_3d_1"
+        // kernel with non 0 zero_point
+        || unit == "kernel_zp"
 }
diff --git a/tflite/src/model.rs b/tflite/src/model.rs
index 2c3539515d..454469ecd1 100644
--- a/tflite/src/model.rs
+++ b/tflite/src/model.rs
@@ -4,6 +4,7 @@ use flatbuffers::FlatBufferBuilder;
 use tract_hir::internal::*;

 use crate::registry::Registry;
+use crate::tensors::{flat_tensor_to_tract_fact, flat_tensor_uses_per_axis_q};
 use crate::tflite;
 use crate::tflite::{Buffer, BufferArgs};

@@ -80,21 +81,24 @@
         let mut target = TypedModel::default();
         let mut mapping = HashMap::new();
         for input in main.inputs().context("No inputs in Tflite model")? {
-            let (fact, name) = crate::tensors::flat_tensor_to_tract_fact(&root, main, input)?;
-            let it = target.add_source(name, fact)?;
-            mapping.insert(input, it);
+            if !flat_tensor_uses_per_axis_q(main, input) {
+                let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
+                let it = target.add_source(name, fact)?;
+                mapping.insert(input, it);
+            }
         }
         for op in main.operators().context("No operators in Tflite model")? {
             for input in op.inputs().context("No input in Tflite operator")? {
                 if let Entry::Vacant(slot) = mapping.entry(input) {
-                    let (fact, name) =
-                        crate::tensors::flat_tensor_to_tract_fact(&root, main, input)?;
+                    let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
                     let value = fact.konst.with_context(|| format!("Error in TF file for operator {:?}. No prior computation nor constant for input {}", op, input))?;
                     let konst = target.add_const(name, value)?;
                     slot.insert(konst);
                 }
             }
-            self.0.op(&root, main, &op, &mut target, &mut mapping)?;
+            self.0
+                .op(&root, main, &op, &mut target, &mut mapping)
+                .with_context(|| format!("Parsing op {op:#?}"))?;
         }
         let outputs: TVec<_> = main
             .outputs()
diff --git a/tflite/src/ops/array.rs b/tflite/src/ops/array.rs
index b7be7ce2d8..b15089dbba 100644
--- a/tflite/src/ops/array.rs
+++ b/tflite/src/ops/array.rs
@@ -15,14 +15,15 @@ use super::wire_fused_activation;

 pub fn register_all(reg: &mut Registry) {
     reg.reg_to_tflite::<AxisOp>(ser_axisop);
-    reg.to_tract.insert(BuiltinOperator::CONCATENATION, de_concat);
-    reg.to_tract.insert(BuiltinOperator::EXPAND_DIMS, de_expand_dims);
-    reg.to_tract.insert(BuiltinOperator::PADV2, de_padv2);
-    reg.to_tract.insert(BuiltinOperator::RESHAPE, de_reshape);
-    reg.to_tract.insert(BuiltinOperator::SHAPE, de_shape);
-    reg.to_tract.insert(BuiltinOperator::SQUEEZE, de_squeeze);
-    reg.to_tract.insert(BuiltinOperator::STRIDED_SLICE, de_strided_slice);
-    reg.to_tract.insert(BuiltinOperator::TRANSPOSE, de_transpose);
+    reg.reg_to_tract(BuiltinOperator::CONCATENATION, de_concat);
+    reg.reg_to_tract(BuiltinOperator::EXPAND_DIMS, de_expand_dims);
+    reg.reg_to_tract(BuiltinOperator::PAD, de_pad);
+    reg.reg_to_tract(BuiltinOperator::PADV2, de_padv2);
+    reg.reg_to_tract(BuiltinOperator::RESHAPE, de_reshape);
+    reg.reg_to_tract(BuiltinOperator::SHAPE, de_shape);
+    reg.reg_to_tract(BuiltinOperator::SQUEEZE, de_squeeze);
+    reg.reg_to_tract(BuiltinOperator::STRIDED_SLICE, de_strided_slice);
+    reg.reg_to_tract(BuiltinOperator::TRANSPOSE, de_transpose);
 }

 fn de_concat(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
@@ -49,6 +50,19 @@ fn de_expand_dims(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     Ok(wire)
 }

+fn de_pad(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let (input, pads) = args_2!(op.facts()?);
+    let pads = pads.konst.as_ref().context("Dynamic PAD is not supported")?;
+    let prefix = op.prefix;
+    let pads: ArrayView2<i32> = pads.to_array_view::<i32>()?.into_dimensionality()?;
+    let pads: Vec<(usize, usize)> =
+        pads.rows().into_iter().map(|row| (row[0] as usize, row[1] as usize)).collect();
+    let mode = tract_hir::ops::array::PadMode::Constant(
+        Tensor::zero_scalar_dt(input.datum_type)?.into(),
+    );
+    op.ctx.target.wire_node(prefix, tract_core::ops::array::Pad { pads, mode }, &op.inputs[0..1])
+}
+
 fn de_padv2(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let (_input, pads, value) = args_3!(op.facts()?);
     let pads = pads.konst.as_ref().context("Dynamic PADV2 is not supported")?;
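Note: PAD's second input, decoded in de_pad above, is a [rank, 2] i32 tensor whose row i holds (before, after) padding for axis i; for example, padding H and W of an NHWC tensor by one on each side:

    // paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
    // -> Pad { pads: vec![(0, 0), (1, 1), (1, 1), (0, 0)], mode: Constant(zero of input dt) }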
diff --git a/tflite/src/ops/cnn.rs b/tflite/src/ops/cnn.rs
index 34a99b6304..097a27758c 100644
--- a/tflite/src/ops/cnn.rs
+++ b/tflite/src/ops/cnn.rs
@@ -14,10 +14,10 @@ use tract_hir::tract_core::ops as core;
 use tract_hir::tract_core::ops::cnn::KernelFormat;

 pub fn register_all(reg: &mut Registry) {
-    reg.to_tract.insert(BuiltinOperator::AVERAGE_POOL_2D, average_pool_2d);
-    reg.to_tract.insert(BuiltinOperator::CONV_2D, conv2d);
+    reg.reg_to_tract(BuiltinOperator::AVERAGE_POOL_2D, average_pool_2d);
+    reg.reg_to_tract(BuiltinOperator::CONV_2D, de_conv2d);
     reg.reg_to_tflite::<ConvUnary>(ser_conv);
-    reg.to_tract.insert(BuiltinOperator::DEPTHWISE_CONV_2D, dw_conv2d);
+    reg.reg_to_tract(BuiltinOperator::DEPTHWISE_CONV_2D, de_dw_conv2d);
     reg.reg_to_tflite::<Pad>(ser_pad);
 }

@@ -58,17 +58,47 @@ fn ser_conv(
         || conv.pool_spec.padding == PaddingSpec::SameUpper
     );
     let node_name = &node.name;
-    let mut inputs = node.inputs.iter().map(|o| builder.outlets_to_tensors[o]).collect_vec();
-    let outputs = (0..node.outputs.len())
-        .map(|o| builder.outlets_to_tensors[&OutletId::new(node.id, o)])
-        .collect_vec();
-    inputs.push(builder.write_fact(&format!("{node_name}.weights"), &conv.kernel)?);
-    inputs.push(builder.write_fact(
-        &format!("{node_name}.bias"),
-        &conv.bias.clone().unwrap_or_else(|| {
-            rctensor1(&vec![0f32; conv.pool_spec.output_channel_override.unwrap()])
-        }),
-    )?);
+    let mut inputs = tvec!(builder.outlets_to_tensors[&node.inputs[0]]);
+    let mut bias = conv.bias.clone().unwrap_or_else(|| {
+        let co = conv.pool_spec.output_channel_override.unwrap();
+        if conv.q_params.is_some() {
+            Tensor::zero::<i32>(&[co]).unwrap()
+        } else {
+            Tensor::zero::<f32>(&[co]).unwrap()
+        }
+        .into_arc_tensor()
+    });
+    if conv.q_params.is_some() {
+        ensure!(conv.kernel.datum_type().zp_scale().0 == 0);
+        let iscale = model.node_input_facts(node.id).unwrap()[0].datum_type.zp_scale().1;
+        let kscale_tract = model.node_input_facts(node.id).unwrap()[2]
+            .konst
+            .as_ref()
+            .unwrap()
+            .as_slice::<f32>()?;
+        let co = conv.kernel.shape()[0];
+        let kscale =
+            if kscale_tract.len() == 1 { vec![kscale_tract[0]; co] } else { kscale_tract.to_vec() };
+        inputs.push(builder.write_fact_with_per_axis_q(
+            &format!("{node_name}.weights"),
+            &conv.kernel,
+            &vec![0i64; conv.kernel.shape()[0]],
+            &kscale,
+            0,
+        )?);
+        let bias_dt = i32::datum_type().quantize(QParams::ZpScale { zero_point: 0, scale: iscale });
+        bias = bias.clone().into_tensor().cast_to_dt(bias_dt)?.into_owned().into_arc_tensor();
+        inputs.push(builder.write_fact_faking_per_axis_q(
+            &format!("{node_name}.bias"),
+            &bias,
+            0,
+        )?);
+    } else {
+        inputs.push(builder.write_fact(&format!("{node_name}.weights"), &conv.kernel)?);
+        inputs.push(builder.write_fact(&format!("{node_name}.bias"), &bias)?);
+    }
+    let output = builder.outlets_to_tensors[&node.id.into()];
+
     let padding =
         if conv.pool_spec.padding == PaddingSpec::Valid { Padding::VALID } else { Padding::SAME };
     if conv.group == 1 {
@@ -85,7 +115,7 @@ fn ser_conv(
         );
         builder.write_op_with_options(
             &inputs,
-            &outputs,
+            &[output],
             BuiltinOp::new(3, 2, BuiltinOperator::CONV_2D, BuiltinOptions::Conv2DOptions),
             options.as_union_value(),
         )
@@ -106,7 +136,7 @@ fn ser_conv(
         );
         builder.write_op_with_options(
             &inputs,
-            &outputs,
+            &[output],
             BuiltinOp::new(
                 4,
                 2,
                 BuiltinOperator::DEPTHWISE_CONV_2D,
                 BuiltinOptions::DepthwiseConv2DOptions,
             ),
             options.as_union_value(),
         )
     }
 }
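Note: per the spec quoted in conv_q.rs, a quantized CONV_2D bias must be int32 with zero_point = 0 and per-axis scale input0_scale * input1_scale[i]. ser_conv above quantizes the bias with the input scale and lets write_fact_faking_per_axis_q (added in ser.rs below) replicate that single (zp, scale) pair along the output-channel axis; e.g. for co = 2 and input scale 0.5 the flatbuffer tensor gets:

    // zero_point = [0, 0], scale = [0.5, 0.5], quantized_dimension = 0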
fn de_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
-fn conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
-    let (_input, kernel, bias) = args_3!(op.facts()?);
+fn de_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let (input, kernel, bias) = args_3!(op.facts()?);
     let kernel = kernel.konst.unwrap();
     let bias = bias.konst.unwrap();
     let kernel_full_shape: TVec<usize> = kernel.shape().into();
@@ -142,20 +172,24 @@ fn de_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
         dilations: Some(dilations),
         output_channel_override: Some(*co),
     };
+    let mut inputs = tvec!(op.inputs[0]);
+    let q_params = super::linearops_quantization_suport(op, &input, &mut inputs)?;
+    let bias_dt = bias.datum_type().unquantized();
+    let bias = bias.into_tensor().cast_to_dt(bias_dt)?.into_owned().into_arc_tensor();
     let conv = core::cnn::ConvUnary {
         pool_spec,
         kernel_fmt: KernelFormat::OHWI,
         kernel,
         group: 1,
         bias: Some(bias),
-        q_params: None,
+        q_params,
     };
-    let wires = op.ctx.target.wire_node(op.prefix, conv, &op.inputs[0..1])?;
+    let wires = op.ctx.target.wire_node(op.prefix, conv, &inputs)?;
     wire_fused_activation(op, &wires, &options.fused_activation_function())
 }

-fn dw_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
-    let (_input, kernel, bias) = args_3!(op.facts()?);
+fn de_dw_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let (input, kernel, bias) = args_3!(op.facts()?);
     let bias = bias.konst.unwrap();
     let kernel = kernel.konst.unwrap();
     let kernel_full_shape: TVec<usize> = kernel.shape().into();
@@ -178,15 +212,17 @@ fn de_dw_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
         dilations: Some(dilations),
         output_channel_override: Some(co),
     };
+    let mut inputs = tvec!(op.inputs[0]);
+    let q_params = super::linearops_quantization_suport(op, &input, &mut inputs)?;
     let conv = core::cnn::ConvUnary {
         pool_spec,
         kernel_fmt: KernelFormat::OHWI,
         kernel,
         group: co,
         bias: Some(bias),
-        q_params: None,
+        q_params,
     };
-    let wires = op.ctx.target.wire_node(op.prefix, conv, &op.inputs[0..1])?;
+    let wires = op.ctx.target.wire_node(op.prefix, conv, &inputs)?;
     wire_fused_activation(op, &wires, &options.fused_activation_function())
 }
diff --git a/tflite/src/ops/math.rs b/tflite/src/ops/math.rs
index a3563b4263..c0e3d8acea 100644
--- a/tflite/src/ops/math.rs
+++ b/tflite/src/ops/math.rs
@@ -3,9 +3,9 @@ use crate::tflite::BuiltinOperator;
 use tract_hir::tract_core::ops as core;

 pub fn register_all(reg: &mut Registry) {
-    reg.binary_ops.push((BuiltinOperator::ADD, core::math::add()));
-    reg.binary_ops.push((BuiltinOperator::SUB, core::math::sub()));
-    reg.binary_ops.push((BuiltinOperator::MUL, core::math::mul()));
-    reg.binary_ops.push((BuiltinOperator::DIV, core::math::div()));
+    reg.reg_binary(BuiltinOperator::ADD, core::math::add());
+    reg.reg_binary(BuiltinOperator::SUB, core::math::sub());
+    reg.reg_binary(BuiltinOperator::MUL, core::math::mul());
+    reg.reg_binary(BuiltinOperator::DIV, core::math::div());
 }
diff --git a/tflite/src/ops/mod.rs b/tflite/src/ops/mod.rs
index fc264e7443..7fabda1e08 100644
--- a/tflite/src/ops/mod.rs
+++ b/tflite/src/ops/mod.rs
@@ -1,4 +1,5 @@
 use tract_hir::internal::*;
+use tract_hir::prelude::tract_itertools::Itertools;

 use crate::registry::{DeserContext, DeserOp, Registry};
 use crate::tflite::ActivationFunctionType;
@@ -40,11 +41,39 @@ fn wire_fused_activation(
         prefix: &prefix,
         flat: op.flat,
         inputs: wires,
+        output_facts: op.output_facts,
     };
     match *activation {
         ActivationFunctionType::NONE => Ok(wires.into()),
-        ActivationFunctionType::RELU => nn::relu(&mut op),
-        ActivationFunctionType::RELU6 => nn::relu6(&mut op),
+        ActivationFunctionType::RELU => nn::de_relu(&mut op),
+        ActivationFunctionType::RELU6 => nn::de_relu6(&mut op),
         af => bail!("Unsupported fused activation type: {af:?}"),
     }
 }
+
+fn linearops_quantization_suport(
+    op: &mut DeserOp,
+    input: &TypedFact,
+    inputs: &mut TVec<OutletId>,
+) -> TractResult<Option<DatumType>> {
+    if op.output_facts[0].datum_type.is_quantized() {
+        let p = &op.prefix;
+        let iqp = input.datum_type.qparams().unwrap();
+        let oqp = op.output_facts[0].datum_type;
+        let k_input = op.flat.inputs().unwrap().get(1);
+        let k_tensor = op.ctx.subgraph.tensors().unwrap().get(k_input as usize);
+        let k_qp = k_tensor.quantization().unwrap();
+        let kscale = k_qp.scale().unwrap().iter().collect_vec();
+        let k_zp = k_qp.zero_point().unwrap().iter().map(|i| i as i32).collect_vec();
+        ensure!(k_zp.iter().all(|x| *x == 0));
+        inputs.push(op.ctx.target.add_const(format!("{p}.k0"), rctensor0(0i8))?);
+        inputs.push(op.ctx.target.add_const(format!("{p}.kscale"), rctensor1(&kscale))?);
+        inputs.push(op.ctx.target.add_const(format!("{p}.i0"), rctensor0(iqp.zp_scale().0 as i8))?);
+        inputs.push(op.ctx.target.add_const(format!("{p}.iscale"), rctensor0(iqp.zp_scale().1))?);
+        inputs.push(op.ctx.target.add_const(format!("{p}.c0"), rctensor0(oqp.zp_scale().0 as i8))?);
+        inputs.push(op.ctx.target.add_const(format!("{p}.cscale"), rctensor0(oqp.zp_scale().1))?);
+        Ok(Some(oqp))
+    } else {
+        Ok(None)
+    }
+}
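Note: linearops_quantization_suport establishes the nine-input convention tract's quantized linear ops use: the operator's regular inputs (data, kernel, bias) followed by six quantization tensors k0, kscale, i0, iscale, c0, cscale. The FULLY_CONNECTED einsum in nn.rs below spells this out in its axes expression:

    // "BI,OI,I,,,,,,->BO": input [B,I], weights [O,I], bias [I],
    // then six quantization inputs with empty axis lists (scalar-ish),
    // producing output [B,O]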
diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs
index 7018e7516c..0912650c81 100644
--- a/tflite/src/ops/nn.rs
+++ b/tflite/src/ops/nn.rs
@@ -1,38 +1,80 @@
 use tract_hir::internal::*;
 use tract_hir::ops::binary::wire_cast;
 use tract_hir::ops::logic::wire_with_rank_broadcast;
+use tract_hir::prelude::tract_itertools::Itertools;
 use tract_hir::tract_core::ops as core;
+use tract_hir::tract_core::ops::cast::Cast;
+use tract_hir::tract_core::ops::einsum::EinSum;

 use crate::registry::{DeserOp, Registry};
-use crate::tflite::BuiltinOperator;
+use crate::tflite::{BuiltinOperator, FullyConnectedOptionsWeightsFormat};

 pub fn register_all(reg: &mut Registry) {
-    reg.to_tract.insert(BuiltinOperator::MEAN, reduce_mean);
-    reg.to_tract.insert(BuiltinOperator::SOFTMAX, softmax);
+    reg.reg_to_tract(BuiltinOperator::FULLY_CONNECTED, de_fully_connected);
+    reg.reg_to_tract(BuiltinOperator::MEAN, de_reduce_mean);
+    reg.reg_to_tract(BuiltinOperator::SOFTMAX, de_softmax);

-    reg.to_tract.insert(BuiltinOperator::RELU, relu);
-    reg.to_tract.insert(BuiltinOperator::RELU6, relu6);
-    reg.element_wise_ops.push((BuiltinOperator::HARD_SWISH, Box::new(core::nn::HardSwish {})));
+    reg.reg_to_tract(BuiltinOperator::RELU, de_relu);
+    reg.reg_to_tract(BuiltinOperator::RELU6, de_relu6);
+    reg.reg_element_wise(BuiltinOperator::HARD_SWISH, Box::new(core::nn::HardSwish {}));
 }

-fn reduce_mean(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+fn de_fully_connected(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let (input, _weights, _bias) = args_3!(op.facts()?);
+    let options = builtin!(op, builtin_options_as_fully_connected_options);
+    ensure!(input.datum_type.is_quantized());
+    ensure!(options.weights_format() == FullyConnectedOptionsWeightsFormat::DEFAULT);
+    ensure!(!options.keep_num_dims());
+    ensure!(!options.asymmetric_quantize_inputs());
+    let mut inputs: TVec<OutletId> = op.inputs.into();
+    let qp = super::linearops_quantization_suport(
+        op,
+        &input,
+        &mut inputs,
+    )?;
+    let operating_dt =
+        if input.datum_type.is_float() { input.datum_type } else { i32::datum_type() };
+    let einsum = EinSum { axes: "BI,OI,I,,,,,,->BO".parse()?, q_params: qp, operating_dt };
+    let wires = op.ctx.target.wire_node(op.prefix, einsum, &inputs)?;
+    super::wire_fused_activation(op, &wires, &options.fused_activation_function())
+}
+
+fn de_reduce_mean(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let (input, axes) = args_2!(op.facts()?);
     let options = builtin!(op, builtin_options_as_reducer_options);
-    ensure!(options.keep_dims());
-    let axes: TVec<usize> =
-        axes.konst.as_ref().unwrap().as_slice::<i32>()?.iter().map(|d| *d as usize).collect();
+    let axes: TVec<usize> = axes
+        .konst
+        .as_ref()
+        .unwrap()
+        .as_slice::<i32>()?
+        .iter()
+        .map(|d| *d as usize)
+        .sorted()
+        .collect();
     let norm: TDim = axes.iter().map(|d| &input.shape[*d]).product();
-    let wire = op.ctx.target.wire_node(
-        op.prefix.to_string() + ".sum",
-        core::nn::Reduce::new(axes, core::nn::Reducer::Sum),
+    let p = &op.prefix;
+    let mut wire = op.ctx.target.wire_node(
+        format!("{p}.sum"),
+        core::nn::Reduce::new(axes.clone(), core::nn::Reducer::Sum),
         &[op.inputs[0]],
     )?;
-    let norm = op.ctx.target.add_const("{prefix}.card", tensor0(norm))?;
-    let wires = wire_cast(op.prefix, op.ctx.target, &[wire[0], norm], input.datum_type)?;
-    wire_with_rank_broadcast(op.prefix, op.ctx.target, core::math::div(), &wires)
+    if !options.keep_dims() {
+        for axis in axes.iter().rev() {
+            wire =
+                op.ctx.target.wire_node(format!("{p}.rm_axis_{axis}"), AxisOp::Rm(*axis), &wire)?;
+        }
+    }
+    let norm = op.ctx.target.add_const(format!("{p}.card"), tensor0(norm))?;
+    let norm = op.ctx.target.wire_node(
+        format!("{p}.as_float"),
+        Cast { to: f32::datum_type() },
+        &[norm],
+    )?;
+    let norm = op.ctx.target.wire_node(format!("{p}.recip"), core::math::recip(), &norm)?;
+    wire_with_rank_broadcast(op.prefix, op.ctx.target, core::quant::scale(), &[norm[0], wire[0]])
 }

-fn softmax(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+fn de_softmax(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let input = args_1!(op.facts()?);
     let options = builtin!(op, builtin_options_as_softmax_options);
     ensure!(options.beta() == 1.0);
@@ -40,7 +82,7 @@ fn de_softmax(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     op.ctx.target.wire_node(op.prefix, softmax, op.inputs)
 }

-pub fn relu(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+pub fn de_relu(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let input = op.inputs[0];
     let zero = op.ctx.target.add_const(format!("{}.zero", op.prefix), tensor0(0f32))?;
     let wires = wire_cast(
@@ -57,8 +99,8 @@ pub fn de_relu(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     )
 }

-pub fn relu6(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
-    let input = relu(op)?[0];
+pub fn de_relu6(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
+    let input = de_relu(op)?[0];
     let six = op.ctx.target.add_const(format!("{}.six", op.prefix), tensor0(6f32))?;
     let wires = wire_cast(
         op.prefix,
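Note: de_reduce_mean now lowers MEAN to sum * recip(cardinality) wired through core::quant::scale() rather than a plain division, which keeps quantized datum types flowing through; numerically, for a reduction over a length-4 axis:

    // mean([2, 4, 6, 8]) = sum * recip(4) = 20 * 0.25 = 5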
diff --git a/tflite/src/registry.rs b/tflite/src/registry.rs
index a8a2c1f6d8..0791bf7894 100644
--- a/tflite/src/registry.rs
+++ b/tflite/src/registry.rs
@@ -15,9 +15,9 @@ pub type ToTflite = fn(&mut SubgraphBuilder, &TypedModel, &TypedNode) -> TractResult<()>;

 pub struct Registry {
     // pub primitives: HashMap,
     // pub unit_element_wise_ops: Vec<(Identifier, Box<dyn ElementWiseMiniOp>)>,
-    pub element_wise_ops: Vec<(BuiltinOperator, Box<dyn ElementWiseMiniOp>)>,
-    pub binary_ops: Vec<(BuiltinOperator, TypedBinOp)>,
-    pub to_tract: HashMap<BuiltinOperator, ToTract>,
+    pub element_wise_ops: Vec<(i32, Box<dyn ElementWiseMiniOp>)>,
+    pub binary_ops: Vec<(i32, TypedBinOp)>,
+    pub to_tract: HashMap<i32, ToTract>,
     pub to_tflite: HashMap<TypeId, ToTflite>,
 }

@@ -32,6 +32,7 @@ pub struct DeserOp<'op> {
     pub prefix: &'op str,
     pub flat: &'op Operator<'op>,
     pub inputs: &'op [OutletId],
+    pub output_facts: &'op [TypedFact],
 }

 impl<'op> DeserOp<'op> {
@@ -48,6 +49,18 @@ impl Registry {
         self.to_tflite.insert(std::any::TypeId::of::<T>(), tflite);
     }

+    pub fn reg_to_tract(&mut self, op: BuiltinOperator, to: ToTract) {
+        self.to_tract.insert(op.0, to);
+    }
+
+    pub fn reg_element_wise(&mut self, tflite: BuiltinOperator, tract: Box<dyn ElementWiseMiniOp>) {
+        self.element_wise_ops.push((tflite.0, tract));
+    }
+
+    pub fn reg_binary(&mut self, tflite: BuiltinOperator, tract: TypedBinOp) {
+        self.binary_ops.push((tflite.0, tract));
+    }
+
     pub fn op(
         &self,
         model: &Model,
         subgraph: &SubGraph,
         flat: &Operator,
         target: &mut TypedModel,
         mapping: &mut HashMap<i32, OutletId>,
     ) -> TractResult<()> {
-        let mut inputs = tvec!();
-        for input in flat.inputs().unwrap() {
-            inputs.push(mapping[&input]);
-        }
+        let inputs: TVec<OutletId> = flat.inputs().unwrap().iter().map(|o| mapping[&o]).collect();
         let tensors = subgraph.tensors().unwrap();
         let prefix = tensors.get(flat.outputs().unwrap().get(0) as usize).name().unwrap();
         let opcode_index = flat.opcode_index();
-        let opcode = model.operator_codes().unwrap().get(opcode_index as _).builtin_code();
+        let operator_code = model.operator_codes().unwrap().get(opcode_index as _);
+        let opcode = if operator_code.deprecated_builtin_code() as i32
+            == BuiltinOperator::PLACEHOLDER_FOR_GREATER_OP_CODES.0
+        {
+            operator_code.builtin_code().0
+        } else {
+            operator_code.deprecated_builtin_code() as i32
+        };
         let ctx = DeserContext { model, subgraph, target };
-        if let Some(ew) =
+        let results = if let Some(ew) =
             self.element_wise_ops.iter().find(|bin| bin.0 == opcode).map(|pair| pair.1.clone())
         {
-            inputs = target.wire_node(prefix, ElementWiseOp(ew.clone()), &inputs)?;
+            target.wire_node(prefix, ElementWiseOp(ew.clone()), &inputs)?
         } else if let Some(bin) =
             self.binary_ops.iter().find(|bin| bin.0 == opcode).map(|pair| pair.1.clone())
         {
-            inputs = wire_with_rank_broadcast(prefix, target, bin, &inputs)?;
+            wire_with_rank_broadcast(prefix, target, bin, &inputs)?
         } else if let Some(op) = self.to_tract.get(&opcode) {
-            inputs = (op)(&mut DeserOp { ctx, prefix, flat, inputs: &inputs })?
+            let output_facts = flat
+                .outputs()
+                .unwrap()
+                .iter()
+                .map(|t| Ok(crate::tensors::flat_tensor_to_tract_fact(model, subgraph, t)?.0))
+                .collect::<TractResult<TVec<TypedFact>>>()?;
+            (op)(&mut DeserOp { ctx, prefix, flat, inputs: &inputs, output_facts: &output_facts })
+                .with_context(|| format!("Opcode is {operator_code:#?}"))?
         } else {
             let facts =
                 inputs.iter().map(|o| target.outlet_fact(*o)).collect::<TractResult<TVec<_>>>()?;
-            bail!("Unsupported operator {opcode:?}, inputs: {facts:#?}")
-        }
-        for (flat, wire) in flat.outputs().unwrap().iter().zip(inputs.iter()) {
+            bail!("Unsupported: {operator_code:#?}, inputs: {facts:#?}")
+        };
+        for (flat, wire) in flat.outputs().unwrap().iter().zip(results.iter()) {
             mapping.insert(flat, *wire);
         }
         Ok(())
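Note: the opcode resolution above mirrors the TFLite schema's compatibility scheme: deprecated_builtin_code is a byte-sized field capped at 127, and newer operators store their real code in the i32 builtin_code while pinning the old field to PLACEHOLDER_FOR_GREATER_OP_CODES (127). The same rule as standalone logic (a sketch, outside the flatbuffers types):

    fn resolve_opcode(deprecated_builtin_code: i8, builtin_code: i32) -> i32 {
        const PLACEHOLDER_FOR_GREATER_OP_CODES: i32 = 127;
        if deprecated_builtin_code as i32 == PLACEHOLDER_FOR_GREATER_OP_CODES {
            builtin_code
        } else {
            deprecated_builtin_code as i32
        }
    }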
diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs
index 2d08afc6f3..85be04ef1a 100644
--- a/tflite/src/rewriter.rs
+++ b/tflite/src/rewriter.rs
@@ -46,7 +46,7 @@ fn kernel_in_ohwi(
     new.kernel_fmt = KernelFormat::OHWI;
     new.kernel = kernel.into_arc_tensor();
     let mut patch = TypedModelPatch::default();
-    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
+    let mut wire = patch.taps(model, &node.inputs)?;
     wire = patch.wire_node(name, new, &wire)?;
     patch.shunt_outside(model, node.id.into(), wire[0])?;
     Ok(Some(patch))
@@ -63,8 +63,8 @@ fn force_n_axis(
     let mut new = conv.clone();
     new.pool_spec.data_format = conv.pool_spec.data_format.with_n();
     let mut patch = TypedModelPatch::default();
-    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-    wire = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &wire)?;
+    let mut wire = patch.taps(model, &node.inputs)?;
+    wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
     wire = patch.wire_node(name, new, &wire)?;
     wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
     patch.shunt_outside(model, node.id.into(), wire[0])?;
@@ -88,8 +88,8 @@ fn make_1d_2d(
     kernel.insert_axis(conv.kernel_fmt.h_axis() + 1)?;
     new.kernel = kernel.into_arc_tensor();
     let mut patch = TypedModelPatch::default();
-    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-    wire = patch.wire_node(format!("{name}.add_dim"), AxisOp::Add(pos), &wire)?;
+    let mut wire = patch.taps(model, &node.inputs)?;
+    wire[0] = patch.wire_node(format!("{name}.add_dim"), AxisOp::Add(pos), &[wire[0]])?[0];
     wire = patch.wire_node(name, new, &wire)?;
     wire = patch.wire_node(format!("{name}.rm_dim"), AxisOp::Rm(pos), &wire)?;
     patch.shunt_outside(model, node.id.into(), wire[0])?;
@@ -117,8 +117,9 @@ fn nchw_to_nhwc(
     let shape = conv.pool_spec.data_format.shape(&fact.shape)?;
     let before = shape.c_axis();
     let after = fact.rank() - 1;
-    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-    wire = patch.wire_node(format!("{name}.nhwc"), AxisOp::Move(before, after), &wire)?;
+    let mut wire = patch.taps(model, &node.inputs)?;
+    wire[0] =
+        patch.wire_node(format!("{name}.nhwc"), AxisOp::Move(before, after), &[wire[0]])?[0];
     wire = patch.wire_node(name, new, &wire)?;
     wire = patch.wire_node(format!("{name}.nchw"), AxisOp::Move(after, before), &wire)?;
     patch.shunt_outside(model, node.id.into(), wire[0])?;
@@ -161,7 +162,7 @@ fn padding(
         }
     }
     let mut patch = TypedModelPatch::default();
-    let mut wires = tvec!(patch.tap_model(model, node.inputs[0])?);
+    let mut wires = patch.taps(model, &node.inputs)?;
     let mut pads = vec![(0usize, 0usize); fact.rank()];
     for (padding, axis) in actual.iter().zip(shape.hw_axes()) {
         pads[axis] = (padding.pad_before.to_usize()?, padding.pad_after.to_usize()?);
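Note on the rewriter hunks above: every rewrite now taps all node inputs (patch.taps) rather
than input 0 only, so secondary inputs such as the bias keep flowing through the patch; only
wire[0] is rewrapped with axis ops. The NCHW/NHWC rewrite sandwiches the conv between
Move(before, after) and the inverse Move(after, before), so the graph stays NCHW outside the
rewritten node. A standalone check of that round trip, with a hypothetical helper mirroring
AxisOp::Move's remove-then-insert semantics:

fn move_axis<T>(v: &mut Vec<T>, from: usize, to: usize) {
    let x = v.remove(from);
    v.insert(to, x);
}

fn main() {
    let mut axes = vec!['N', 'C', 'H', 'W'];
    move_axis(&mut axes, 1, 3); // c_axis to last: NCHW -> NHWC
    assert_eq!(axes, ['N', 'H', 'W', 'C']);
    move_axis(&mut axes, 3, 1); // the inverse Move restores NCHW
    assert_eq!(axes, ['N', 'C', 'H', 'W']);
}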
diff --git a/tflite/src/ser.rs b/tflite/src/ser.rs
index 6d4026d3fe..3956beb7b7 100644
--- a/tflite/src/ser.rs
+++ b/tflite/src/ser.rs
@@ -6,8 +6,8 @@ use tract_hir::tract_core::ops::source::TypedSource;
 use crate::registry::Registry;
 use crate::tflite::{
     Buffer, BufferArgs, BuiltinOperator, BuiltinOptions, CustomOptionsFormat, Model, ModelArgs,
-    Operator, OperatorArgs, OperatorCode, OperatorCodeArgs, SubGraph, SubGraphArgs, Tensor,
-    TensorArgs,
+    Operator, OperatorArgs, OperatorCode, OperatorCodeArgs, QuantizationDetails,
+    QuantizationParameters, QuantizationParametersArgs, SubGraph, SubGraphArgs, Tensor, TensorArgs,
 };
 use flatbuffers::{FlatBufferBuilder, UnionWIPOffset, WIPOffset};
 
@@ -105,6 +105,73 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
         &mut self,
         name: impl AsRef<str>,
         fact: impl Into<TypedFact>,
+    ) -> TractResult<i32> {
+        let fact = fact.into();
+        if let Some(qp) = fact.datum_type.qparams() {
+            self.write_fact_with_per_axis_q(
+                name,
+                fact,
+                &[qp.zp_scale().0 as i64],
+                &[qp.zp_scale().1],
+                0,
+            )
+        } else {
+            self.write_fact_with_quantization(name, fact, None)
+        }
+    }
+
+    pub fn write_fact_faking_per_axis_q(
+        &mut self,
+        name: impl AsRef<str>,
+        fact: impl Into<TypedFact>,
+        axis: usize,
+    ) -> TractResult<i32> {
+        let fact = fact.into();
+        if let Some(qp) = fact.datum_type.qparams() {
+            let dim = fact.shape[axis].to_usize()?;
+            self.write_fact_with_per_axis_q(
+                name,
+                fact,
+                &vec![qp.zp_scale().0 as i64; dim],
+                &vec![qp.zp_scale().1; dim],
+                axis,
+            )
+        } else {
+            self.write_fact_with_quantization(name, fact, None)
+        }
+    }
+
+    pub fn write_fact_with_per_axis_q(
+        &mut self,
+        name: impl AsRef<str>,
+        fact: impl Into<TypedFact>,
+        zp: &[i64],
+        scale: &[f32],
+        axis: usize,
+    ) -> TractResult<i32> {
+        let fact = fact.into();
+        let zero_point = self.fb().create_vector(zp);
+        let scale = self.fb().create_vector(scale);
+        let qp = QuantizationParameters::create(
+            self.fb(),
+            &QuantizationParametersArgs {
+                min: None,
+                max: None,
+                zero_point: Some(zero_point),
+                scale: Some(scale),
+                details: None,
+                details_type: QuantizationDetails::NONE,
+                quantized_dimension: axis as i32,
+            },
+        );
+        self.write_fact_with_quantization(name, fact, Some(qp))
+    }
+
+    pub fn write_fact_with_quantization(
+        &mut self,
+        name: impl AsRef<str>,
+        fact: impl Into<TypedFact>,
+        quantization: Option<WIPOffset<QuantizationParameters<'f>>>,
     ) -> TractResult<i32> {
         let fact = fact.into();
         let buffer = if let Some(k) = &fact.konst {
@@ -124,7 +191,7 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
             name: Some(name),
             buffer,
             is_variable: false,
-            quantization: None,
+            quantization,
             shape: Some(shape),
             type_: fact.datum_type.try_into()?,
             sparsity: None,
@@ -140,6 +207,10 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
     fn write_subgraph(&mut self, model: &TypedModel) -> TractResult<()> {
         for &node_id in &model.eval_order()? {
             let node = &model.nodes[node_id];
+            // will serialize constants at the demand of operators only
+            if node.op_is::<Const>() {
+                continue;
+            }
             // create fb tensors for all outputs
             for (slot, output) in node.outputs.iter().enumerate() {
                 let name = model
@@ -151,8 +222,8 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
                 let outlet = OutletId::new(node.id, slot);
                 self.outlets_to_tensors.insert(outlet, tensor);
             }
-            // Source and Const are not reified
-            if node.op_is::<TypedSource>() || node.op_is::<Const>() {
+            // Source inputs are not reified
+            if node.op_is::<TypedSource>() {
                 continue;
             } else if let Some(to_tflite) =
                 self.model.registry.to_tflite.get(&(*(node.op)).type_id())
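Note on the serializer above: TFLite models quantization as vectors of zero-points and scales
plus a quantized_dimension, so per-tensor quantization is just the length-1 case, and
write_fact_faking_per_axis_q broadcasts a uniform zero-point/scale dim times along the chosen
axis. A sketch of both shapes with a plain illustrative struct (descriptive names, not the
generated flatbuffers API):

struct QParamsSketch {
    zero_point: Vec<i64>,
    scale: Vec<f32>,
    quantized_dimension: i32,
}

// Per-tensor quantization: the degenerate per-axis case, length-1 vectors on axis 0.
fn per_tensor(zp: i64, scale: f32) -> QParamsSketch {
    QParamsSketch { zero_point: vec![zp], scale: vec![scale], quantized_dimension: 0 }
}

// "Faked" per-axis quantization: the same scalar parameters repeated to the axis dimension.
fn fake_per_axis(zp: i64, scale: f32, dim: usize, axis: usize) -> QParamsSketch {
    QParamsSketch {
        zero_point: vec![zp; dim],
        scale: vec![scale; dim],
        quantized_dimension: axis as i32,
    }
}

fn main() {
    assert_eq!(per_tensor(0, 0.5).scale, [0.5]);
    assert_eq!(fake_per_axis(0, 0.5, 3, 1).scale, [0.5, 0.5, 0.5]);
}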
diff --git a/tflite/src/tensors.rs b/tflite/src/tensors.rs
index f869b3ef3e..c6d840853a 100644
--- a/tflite/src/tensors.rs
+++ b/tflite/src/tensors.rs
@@ -37,7 +37,7 @@ impl TryFrom<BufferTensorType> for DatumType {
 impl TryFrom<DatumType> for BufferTensorType {
     type Error = TractError;
     fn try_from(value: DatumType) -> Result<Self, Self::Error> {
-        Ok(match value {
+        Ok(match value.unquantized() {
             DatumType::Bool => BufferTensorType::BOOL,
             DatumType::U8 => BufferTensorType::UINT8,
             DatumType::U16 => BufferTensorType::UINT16,
@@ -80,22 +80,60 @@ fn create_tensor(dt: DatumType, shape: &[usize], data: &[u8]) -> TractResult<Ten
     }
 }
 
+pub fn flat_tensor_uses_per_axis_q<'m>(graph: &'m SubGraph<'m>, id: i32) -> bool {
+    let flat = graph.tensors().unwrap().get(id as _);
+    if let Some(qp) = flat.quantization() {
+        if let (Some(scale), Some(zp)) = (qp.scale(), qp.zero_point()) {
+            return !scale.iter().all_equal() || !zp.iter().all_equal();
+        }
+    }
+    false
+}
+
+pub fn per_axis_q_params<'m>(
+    graph: &'m SubGraph<'m>,
+    id: i32,
+) -> TractResult<(Vec<i32>, Vec<f32>)> {
+    let flat = graph.tensors().unwrap().get(id as _);
+    let Some(qp) = flat.quantization() else { bail!("Unquantized value") };
+    let (Some(scale), Some(zp)) = (qp.scale(), qp.zero_point()) else { bail!("No ZP/scale found") };
+    Ok((zp.iter().map(|i| i as i32).collect_vec(), scale.iter().collect_vec()))
+}
+
 pub fn flat_tensor_to_tract_fact<'m>(
     &model: &'m Model<'m>,
     graph: &'m SubGraph<'m>,
     id: i32,
 ) -> TractResult<(TypedFact, &'m str)> {
     let flat = graph.tensors().unwrap().get(id as _);
-    let dt: DatumType = flat.type_().try_into()?;
+    let mut dt: DatumType = flat.type_().try_into()?;
+    if let Some(qp) = flat.quantization() {
+        if let (Some(scale), Some(zp)) = (qp.scale(), qp.zero_point()) {
+            dt = dt.quantize(QParams::ZpScale { zero_point: zp.get(0) as _, scale: scale.get(0) })
+        }
+    }
     let mut fact = dt.fact(flat.shape().unwrap().iter().map(|d| d as usize).collect_vec());
     let buffer_ix = flat.buffer() as usize;
     if buffer_ix != 0 {
         let buffer = model.buffers().unwrap().get(flat.buffer() as usize);
         if let Some(data) = buffer.data() {
-            let data =
-                create_tensor(fact.datum_type, fact.shape.as_concrete().unwrap(), data.bytes())?;
+            let mut data = create_tensor(
+                fact.datum_type.unquantized(),
+                fact.shape.as_concrete().unwrap(),
+                data.bytes(),
+            )?;
+            unsafe {
+                data.set_datum_type(dt);
+            };
             fact = data.into();
         }
     }
     Ok((fact, flat.name().unwrap()))
 }
+
+#[derive(Clone, Debug)]
+pub struct PerAxisQ {
+    axis: usize,
+    zp: Vec<i32>,
+    scale: Vec<f32>,
+}
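Note on the loading path above: a tensor is treated as per-axis quantized exactly when its
scale or zero-point vector is non-constant; otherwise element 0 serves as the per-tensor
parameters. Constant payloads are materialized with the unquantized datum type and then
retagged through set_datum_type, which is sound because quantized types share the byte layout
of their integer carriers. A standalone sketch of the predicate, assuming the vectors were
already pulled out of the flatbuffer (illustrative helper, not tract API):

fn uses_per_axis_q(scale: &[f32], zp: &[i64]) -> bool {
    // per-axis exactly when either vector is not constant (itertools' all_equal, spelled out)
    !scale.windows(2).all(|w| w[0] == w[1]) || !zp.windows(2).all(|w| w[0] == w[1])
}

fn main() {
    assert!(!uses_per_axis_q(&[0.5], &[0])); // scalar zp/scale: per-tensor
    assert!(uses_per_axis_q(&[0.5, 0.25], &[0, 0])); // distinct per-channel scales
}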