Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More tflite support: Q Conv #1140

Merged
merged 25 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .travis/regular-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,17 @@ then
exit 0
fi

OLD_CACHEDIR=$CACHEDIR
mkdir -p $CACHEDIR/big
export CACHEDIR=$CACHEDIR/big
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p core-proptest-pulse $ALL_FEATURES
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p lstm-proptest-onnx-vs-tf $ALL_FEATURES
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p nnef-inceptionv3 $ALL_FEATURES
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-inceptionv3 $ALL_FEATURES
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-mobilenet-v2 $ALL_FEATURES
cargo -q test $CARGO_EXTRA -q --profile opt-no-lto -p tf-moz-deepspeech $ALL_FEATURES
if [ -n "$GITHUB_ACTIONS" ]
then
rm -r $OLD_CACHEDIR/big
fi
CACHEDIR=$OLD_CACHEDIR
38 changes: 22 additions & 16 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl ConvUnary {
use crate::ops::matmul::mir_quant as qmm;

let c_dt = self.q_params.unwrap();
let [a0, a_scale, mut b0, b_scale, c0, c_scale] = wires[1..] else {
let [a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires[1..] else {
bail!("Wrong number of inputs")
};

Expand All @@ -168,6 +168,18 @@ impl ConvUnary {
let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

if !model.outlet_fact(a_scale)?.shape.volume().is_one() {
// requant is performed before geo_reshape, so we need at most one geo axis to the
// right
if !output_shape.fmt.c_is_last() {
a_scale = model.wire_node(
format!("{name}.a_scale_axis_fix"),
AxisOp::Add(1),
&[a_scale],
)?[0];
}
}

let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;

let im2col = model.wire_node(
Expand Down Expand Up @@ -641,7 +653,7 @@ impl ConvUnary {
return Ok(None);
};
let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
if value.cast_to_scalar::<i64>()? != 0
if !value.is_zero()?
|| (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
|| pad.pads[shape.c_axis()] != (0, 0)
{
Expand Down Expand Up @@ -671,27 +683,21 @@ impl ConvUnary {
if self.q_params.is_some() || self.group != 1 {
return Ok(None);
}
let &[succ] = &*node.outputs[0].successors else {
return Ok(None)
};
let Some(bin) = model.node(succ.node).op_as::<TypedBinOp>() else {
return Ok(None)
};
let &[succ] = &*node.outputs[0].successors else { return Ok(None) };
let Some(bin) = model.node(succ.node).op_as::<TypedBinOp>() else { return Ok(None) };
let other_input = model.node(succ.node).inputs[1 - succ.slot];
let other_fact = &model.outlet_fact(other_input)?;
let Some(konst) = &other_fact.konst else {
return Ok(None)
};
let Some(konst) = &other_fact.konst else { return Ok(None) };
let axes_mapping = model.node_axes_mapping(succ.node)?;
let input_shape =
self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
let conv_c_axis = input_shape.c_axis();
let &[konst_c_axis] = &*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1- succ.slot] else {
return Ok(None)
};
let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else {
return Ok(None)
let &[konst_c_axis] =
&*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1 - succ.slot]
else {
return Ok(None);
};
let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else { return Ok(None) };
let operand_for_bias = if konst.shape()[konst_c_axis] == co && konst.len() == co {
konst.clone().into_tensor().into_shape(&[co])?
} else if konst.len() == 1 {
Expand Down
22 changes: 20 additions & 2 deletions core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ pub(super) fn ensure_mkn_axes<'a>(
})
.max_by_key(|a| &output_shape[a.outputs[0][0]]);
let Some(n_axis) = n_axis else {
return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, true, &[k_axis, m_axis])?));
return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(
op,
model,
node,
true,
&[k_axis, m_axis],
)?));
};
for axis in op.axes.iter_all_axes() {
let one = TDim::one();
Expand Down Expand Up @@ -218,10 +224,22 @@ fn dequant_output(
let name = &node.name;
let mut patch = TypedModelPatch::new("Dequantizing einsum");
let taps = patch.taps(model, &node.inputs)?;
let [a, b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
let [a, b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
bail!("Expect exactly 9 inputs")
};

if !patch.outlet_fact(a_scale)?.shape.volume().is_one() {
let q_axis_in_output = op.axes.axis((InOut::In(4), 0))?.outputs[0][0];
let output_rank = node.outputs[0].fact.rank();
for i in 1..(output_rank - q_axis_in_output) {
a_scale = patch.wire_node(
format!("{name}.a_scale_axis_fix_{i}"),
AxisOp::Add(i),
&[a_scale],
)?[0];
}
}

let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut a0, "a0")?;
let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut b0, "b0")?;

Expand Down
10 changes: 10 additions & 0 deletions data/src/datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ impl DatumType {
_ => *self,
}
}

pub fn quantize(&self, qparams: QParams) -> DatumType {
match self {
DatumType::I8 => DatumType::QI8(qparams),
DatumType::U8 => DatumType::QI8(qparams),
DatumType::I32=> DatumType::QI32(qparams),
_ => panic!("Can't quantize {self:?}"),
}
}

#[inline(always)]
pub fn zp_scale(&self) -> (i32, f32) {
self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
Expand Down
3 changes: 2 additions & 1 deletion data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ impl Approximation {
fn atol_and_rtol(&self, dt: &DatumType) -> (f64, f64) {
use Approximation::*;
match (self, dt) {
(Exact, _) => (0.0, 0.0),
(Close, DatumType::F16) => (1e-3, 1e-3),
(Approximate, DatumType::F16) => (1e-3, 5e-3),
(Exact, _) => (0.0, 0.0),
(Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0.),
(Close, _) => (1e-7, 1e-7),
(Approximate, _) => (1e-4, 5e-4),
}
Expand Down
50 changes: 43 additions & 7 deletions test-rt/infra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use dyn_clone::DynClone;
use itertools::Itertools;
use proptest::prelude::{any_with, Arbitrary};
use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
use tract_core::internal::Approximation;
use tract_core::runtime::Runtime;
use tract_core::tract_data::TractResult;

Expand All @@ -16,7 +17,11 @@ pub fn setup_test_logger() {
pub type TestResult = anyhow::Result<()>;

pub trait Test: 'static + Send + Sync + DynClone {
fn run(&self, runtime: &dyn Runtime) -> TestResult;
fn run(&self, id: &str, runtime: &dyn Runtime) -> TestResult {
self.run_with_approx(id, runtime, Approximation::Close)
}
fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation)
-> TestResult;
}

dyn_clone::clone_trait_object!(Test);
Expand Down Expand Up @@ -87,6 +92,32 @@ impl TestSuite {
}
}

pub fn get_sub(&self, id: &str) -> &TestSuite {
match self {
TestSuite::Node(n) => {
if let Some((head, tail)) = id.split_once("::") {
n[head].get_sub(tail)
} else {
n[id].get_sub("")
}
}
TestSuite::Leaf(_, _) => panic!(),
}
}

pub fn get_sub_mut(&mut self, id: &str) -> &mut TestSuite {
match self {
TestSuite::Node(n) => {
if let Some((head, tail)) = id.split_once("::") {
n.get_mut(head).unwrap().get_sub_mut(tail)
} else {
n.get_mut(id).unwrap()
}
}
TestSuite::Leaf(_, _) => panic!(),
}
}

fn ignore_rec(&mut self, prefix: &mut Vec<String>, ign: &dyn Fn(&[String]) -> bool) {
match self {
TestSuite::Node(n) => {
Expand All @@ -111,6 +142,7 @@ impl TestSuite {
prefix: &str,
id: &str,
rs: &mut impl Write,
approx: &str,
) -> TractResult<()> {
let full_id = [prefix, id].into_iter().filter(|s| s.len() > 0).join("::");
match self {
Expand All @@ -120,7 +152,7 @@ impl TestSuite {
writeln!(rs, "use super::*;").unwrap();
}
for (id, test) in h.iter().sorted_by_key(|(k, _)| k.to_owned()) {
test.dump(test_suite, runtime, &full_id, id, rs)?;
test.dump(test_suite, runtime, &full_id, id, rs, approx)?;
}
if id.len() > 0 {
writeln!(rs, "}}").unwrap();
Expand All @@ -133,21 +165,25 @@ impl TestSuite {
writeln!(rs, "#[ignore]").unwrap();
}
writeln!(rs, "fn {id}() -> TractResult<()> {{",).unwrap();
writeln!(rs, " {test_suite}.get({full_id:?}).run({runtime})",).unwrap();
writeln!(
rs,
" {test_suite}.get({full_id:?}).run_with_approx({full_id:?}, {runtime}, {approx})",
)
.unwrap();
writeln!(rs, "}}").unwrap();
}
}
Ok(())
}

pub fn test_runtime(&self, name: &str, test_suite: &str, runtime: &str) {
pub fn test_runtime(&self, name: &str, test_suite: &str, runtime: &str, approx: &str) {
let out_dir = std::env::var("OUT_DIR").unwrap();
let out_dir = std::path::PathBuf::from(out_dir);
let test_dir = out_dir.join("tests");
std::fs::create_dir_all(&test_dir).unwrap();
let test_file = test_dir.join(name).with_extension("rs");
let mut rs = std::fs::File::create(test_file).unwrap();
self.dump(test_suite, runtime, "", "", &mut rs).unwrap();
self.dump(test_suite, runtime, "", "", &mut rs, approx).unwrap();
}
}

Expand All @@ -160,13 +196,13 @@ impl<A: Arbitrary + Test + Clone> Test for ProptestWrapper<A>
where
A::Parameters: Clone + Send + Sync,
{
fn run(&self, runtime: &dyn Runtime) -> TestResult {
fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation) -> TestResult {
let mut runner = TestRunner::new(Config {
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any_with::<A>(self.0.clone()), |v| {
v.run(runtime).unwrap();
v.run_with_approx(id, runtime, approx).unwrap();
Ok(())
})?;
Ok(())
Expand Down
14 changes: 10 additions & 4 deletions test-rt/suite-conv/src/conv_f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,18 @@ impl Arbitrary for ConvProblem {
}

impl Test for ConvProblem {
fn run(&self, runtime: &dyn Runtime) -> TestResult {
fn run_with_approx(
&self,
id: &str,
runtime: &dyn Runtime,
approx: Approximation,
) -> TestResult {
let reference = self.reference().into_tensor();
let mut output =
runtime.prepare(self.tract()?)?.run(tvec![self.data.clone().into_tvalue()])?;
let mut model = self.tract()?;
model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
let mut output = runtime.prepare(model)?.run(tvec![self.data.clone().into_tvalue()])?;
let output = output.remove(0).into_tensor();
output.close_enough(&reference, true)
output.close_enough(&reference, approx)
}
}

Expand Down
Loading