Skip to content

Commit

Permalink
Merge c281b86 into 68a96cd
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jul 10, 2024
2 parents 68a96cd + c281b86 commit adf8f3b
Show file tree
Hide file tree
Showing 38 changed files with 1,873 additions and 665 deletions.
10 changes: 5 additions & 5 deletions core/src/ops/cnn/conv/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::lir_unary::{LirMatMulUnary, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::frame::Packer;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::MatMatMul;

#[derive(Debug, Clone, new, Hash)]
Expand Down Expand Up @@ -74,7 +74,7 @@ impl Conv {
&self,
model: &mut TypedModel,
name: &str,
packer: Packer,
packer: PackedFormat,
kernel: OutletId,
) -> TractResult<OutletId> {
Ok(model.wire_node(
Expand Down Expand Up @@ -177,7 +177,7 @@ impl Conv {
&sum_ker_n_g_c,
)?;

ensure!(mmm.packings()[packing].1.downcast_ref::<Packer>().is_some());
ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
let mut sum_x = model.wire_node(
format!("{name}.sum_x"),
super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
Expand Down Expand Up @@ -405,7 +405,7 @@ impl Conv {
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
let packer = mmm.packings()[packing]
.1
.downcast_ref::<Packer>()
.downcast_ref::<PackedFormat>()
.with_context(|| {
format_err!(
"Quand Im2Col expects regular packed format, got {:?}",
Expand Down Expand Up @@ -486,7 +486,7 @@ impl Conv {
ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
let a_pack = mmm.packings()[packing]
.0
.downcast_ref::<Packer>()
.downcast_ref::<PackedFormat>()
.context("Conv expects wights in regular packed format")?
.clone();
let packed_ker = self
Expand Down
18 changes: 9 additions & 9 deletions core/src/ops/cnn/conv/im2col.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use tract_linalg::frame::{MatMatMul, Packer, PackingWriter};
use tract_linalg::mmm::{EagerPackedInput, MMMInput};
use tract_linalg::frame::{MatMatMul, PackedFormat, PackingWriter};
use tract_linalg::mmm::{EagerPackedInput, MMMInputValue};

use crate::internal::*;
use ndarray::prelude::*;
Expand All @@ -21,7 +21,7 @@ struct SymbolicGeometry {
group: usize,
pool_spec: PoolSpec,
pool_geometry: PoolGeometry,
b_pack: Packer,
b_pack: PackedFormat,
k: usize,
}

Expand All @@ -30,15 +30,15 @@ struct ConcreteGeometry {
pool: ConcretePoolGeometry,
pub n: usize,
k: usize,
pub b_pack: Packer,
pub b_pack: PackedFormat,
pub ci_per_group: usize,
patcher: Patcher,
input_shape_with_n: DataShape,
packed_shape: TVec<usize>, // always Batch,Group,Packed
}

impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
pub fn b_pack(&self) -> &Packer {
pub fn b_pack(&self) -> &PackedFormat {
match self {
GeometryBound::Symbolic(s) => &s.b_pack,
GeometryBound::Concrete(s) => &s.b_pack,
Expand Down Expand Up @@ -104,7 +104,7 @@ impl Im2Col {
) -> TractResult<Im2Col> {
let b_pack = mmm.packings()[0]
.1
.downcast_ref::<Packer>()
.downcast_ref::<PackedFormat>()
.context("Im2Col expects regular packed format")?
.clone();

Expand Down Expand Up @@ -178,12 +178,12 @@ impl EvalOp for Im2Col {
g,
pad_value
))?;
let input: Box<dyn MMMInput> = Box::new(EagerPackedInput {
packed: data,
let input: Box<dyn MMMInputValue> = Box::new(EagerPackedInput {
format: Box::new(geometry.b_pack.clone()),
packed: data.into_blob()?,
panel_bytes,
k: geometry.k,
mn: geometry.n,
r: geometry.b_pack.r,
});
output_view[[i, g]] = input.into();
}
Expand Down
14 changes: 7 additions & 7 deletions core/src/ops/cnn/conv/lazy_im2col.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::internal::*;
use std::fmt::{Debug, Display};
use std::ops::Range;
use tract_linalg::frame::{Packer, PackingWriter};
use tract_linalg::mmm::MMMInput;
use tract_linalg::frame::{PackedFormat, PackingWriter};
use tract_linalg::mmm::MMMInputValue;

#[derive(Clone, Debug, Hash, PartialEq)]
pub struct LazyIm2colParams {
pub packer: Packer,
pub packer: PackedFormat,
pub n_byte_offsets: Vec<isize>,
pub k_byte_offsets: Vec<isize>,
}
Expand All @@ -32,7 +32,7 @@ impl EvalOp for LazyIm2Col {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let tensor = args_1!(inputs);
let input: Box<dyn MMMInput> =
let input: Box<dyn MMMInputValue> =
Box::new(LazyIm2colInput { tensor, im2col: self.params.clone() });
let input = Opaque(Arc::new(input));
Ok(tvec!(tensor2(&[[input]]).into_tvalue()))
Expand Down Expand Up @@ -289,14 +289,14 @@ impl LazyIm2colInput {
}
}

impl MMMInput for LazyIm2colInput {
impl MMMInputValue for LazyIm2colInput {
fn scratch_panel_buffer_layout(&self) -> Option<std::alloc::Layout> {
let k = self.im2col.k_byte_offsets.len();
Some(self.im2col.packer.single_panel_layout(k, self.tensor.datum_type().size_of()))
}

fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> *const u8 {
dispatch_copy!(Self::do_panel(self.tensor.datum_type())(self, i, buffer))
fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
Ok(dispatch_copy!(Self::do_panel(self.tensor.datum_type())(self, i, buffer)))
}

fn k(&self) -> usize {
Expand Down
8 changes: 4 additions & 4 deletions core/src/ops/cnn/conv/q_sum_b.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::internal::*;
use tract_linalg::mmm::MMMInput;
use tract_linalg::mmm::MMMInputValue;
use tract_ndarray::prelude::*;

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -61,7 +61,7 @@ impl QSumB {
let mut output_view = output_view.index_axis_mut(Axis(0), g);
let output_slice = output_view.as_slice_mut().unwrap();
let payload = payloads[[b, g]]
.downcast_ref::<Box<dyn MMMInput>>()
.downcast_ref::<Box<dyn MMMInputValue>>()
.context("Expected MMMInputs")?;
match self.dt.unquantized() {
DatumType::I8 => self.eval_t::<i8>(&**payload, output_slice)?,
Expand All @@ -75,13 +75,13 @@ impl QSumB {

fn eval_t<T: Datum + tract_num_traits::AsPrimitive<i32>>(
&self,
input: &dyn MMMInput,
input: &dyn MMMInputValue,
output: &mut [i32],
) -> TractResult<()> {
let (r, k, n) = (input.r(), input.k(), input.mn());
let panels = n.divceil(r);
for ipanel in 0..panels {
let panel = input.panel_bytes(ipanel, None);
let panel = input.panel_bytes(ipanel, None)?;
let panel: &[T] = unsafe { std::slice::from_raw_parts(panel as *const T, r * k) };
let mut vec = vec![0i32; r];
for ik in 0..k {
Expand Down
126 changes: 100 additions & 26 deletions core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use tract_linalg::frame::Packer;
use tract_linalg::frame::block_quant::PackedBlockQuantFormat;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, MatMatMul};

use super::*;
use crate::ops::cast::cast;
use crate::ops::math::add;
use crate::ops::matmul::de_block_quant::DeBlockQuant;
use crate::ops::matmul::lir_unary::{
AddMatMulGeometry, LirMatMulUnary, MapOutputAxisToInput, ProtoFusedSpec,
};
Expand Down Expand Up @@ -41,7 +44,7 @@ pub(crate) fn codegen(
}
}

pub(super) fn ensure_mkn_axes<'a>(
pub(crate) fn ensure_mkn_axes<'a>(
op: &'a EinSum,
model: &TypedModel,
node: &TypedNode,
Expand Down Expand Up @@ -292,6 +295,77 @@ fn dequant(
Ok(Some(patch))
}

/// Try to pick a mat-mat-mul kernel and packing that can consume the
/// block-quantized weights directly.
///
/// If input 0 of `node` is produced by a `DeBlockQuant` op, scan every
/// available mmm implementation for a packing whose A-side input format is a
/// `PackedBlockQuantFormat` using the same block-quant scheme as the
/// dequantizer. Among the candidates:
/// * when both `m` and `n` are concrete, pick the kernel minimizing total
///   padded work (tiles are mr×nr; a partial tile still costs a full one);
/// * when either dim is symbolic, assume it is large and prefer the kernel
///   with the biggest tile.
///
/// # Returns
/// `Ok(Some((kernel, packing_index)))` on a match, `Ok(None)` when input 0 is
/// not block-quantized or no kernel supports its scheme.
fn select_kernel_and_packing(
    model: &TypedModel,
    node: &TypedNode,
    m: &TDim,
    n: &TDim,
) -> TractResult<Option<(Box<dyn MatMatMul>, usize)>> {
    // Only relevant when the weights come straight out of a dequantizer.
    let Some(dbq) = model.node(node.inputs[0].node).op_as::<DeBlockQuant>() else {
        return Ok(None);
    };
    // Collect every (kernel, packing index) whose A packing understands the
    // same block-quant format as the dequantizer emits.
    let mut options: Vec<(&Box<dyn MatMatMul>, usize)> = vec![];
    for imp in tract_linalg::ops().mmm_impls() {
        for (packing, (pack_a, _)) in imp.packings().iter().enumerate() {
            if let Some(input) = pack_a.downcast_ref::<PackedBlockQuantFormat>() {
                if input.bq.same_as(&*dbq.bq) {
                    options.push((imp, packing));
                }
            }
        }
    }
    if options.is_empty() {
        return Ok(None);
    }
    let pair = if let (Some(m), Some(n)) = (m.as_i64(), n.as_i64()) {
        // Concrete geometry: minimize padded work. divceil rounds partial
        // tiles up, so this counts the full mr×nr cost of every tile touched.
        options
            .iter()
            .min_by_key(|a| {
                ((m as usize).divceil(a.0.mr()) * (n as usize).divceil(a.0.nr()))
                    * a.0.mr()
                    * a.0.nr()
            })
            .unwrap()
    } else {
        // Symbolic dims: can not cost it out, so favor the largest tile.
        options.iter().max_by_key(|a| a.0.mr() * a.0.nr()).unwrap()
    };
    Ok(Some((pair.0.clone(), pair.1)))
}

/// Wire the packing stage for one einsum operand (`input` 0 = A, 1 = B) into
/// `patch`, according to the `packer` format the selected kernel expects.
///
/// Two formats are handled:
/// * `PackedFormat`: insert a regular `MatMatMulPack` node on the operand.
/// * `PackedBlockQuantFormat`: the operand must be a const weight tensor
///   dequantized by a single-consumer `DeBlockQuant` precursor; the weights
///   are then re-packed at patch-build time and wired as a constant.
///
/// Returns the outlet carrying the packed operand. Fails on any other packing
/// format, or when the block-quant preconditions do not hold.
fn wire_packing(
    model: &TypedModel,
    node: &TypedNode,
    input: usize,
    patch: &mut TypedModelPatch,
    packer: &dyn MMMInputFormat,
    k_axis: usize,
    mn_axis: usize,
) -> TractResult<OutletId> {
    let name = format!("{}.pack_{}", node.name, ['a', 'b'][input]);
    if let Some(packed_format) = packer.downcast_ref::<PackedFormat>().cloned() {
        // Regular packing: tap the operand and run it through MatMatMulPack.
        let wire = patch.tap_model(model, node.inputs[input])?;
        let pack_a = MatMatMulPack { packer: packed_format, k_axis, mn_axis };
        Ok(patch.wire_node(&name, pack_a, &[wire])?[0])
    } else if let Some(pbqf) = packer.downcast_ref::<PackedBlockQuantFormat>() {
        // NOTE(review): block-quant path only supports the (mn, k) = (0, 1)
        // axis layout — presumably the canonical weight layout; confirm.
        ensure!(k_axis == 1);
        ensure!(mn_axis == 0);
        let prec = model.node(node.inputs[input].node);
        // The dequantizer must have no other consumer, as it is bypassed here.
        ensure!(prec.outputs[0].successors.len() == 1);
        let Some(deq) = prec.op_as::<DeBlockQuant>() else {
            bail!("Precursor should be dequantizer")
        };
        // Kernel packing and dequantizer must agree on the block-quant scheme.
        ensure!(deq.bq.same_as(&*pbqf.bq));
        let Some(weights) = &model.outlet_fact(prec.inputs[0])?.konst else {
            bail!("Block quant packing with non-const inputs")
        };
        let k = deq.fact.shape[k_axis].to_usize()?;
        // Re-pack the raw quantized blob for the kernel's panel width r,
        // and wire the result as an opaque constant.
        let packed = deq.bq.pack(weights.to_scalar::<Blob>()?, k, pbqf.r)?;
        let mmm_input: Box<dyn MMMInputValue> = Box::new(packed);
        patch.add_const(name, tensor0(Opaque::from(mmm_input)))
    } else {
        bail!("Unexpected packing format: {:?}", packer);
    }
}

fn lir_mat_mul_unary(
op: &EinSum,
model: &TypedModel,
Expand Down Expand Up @@ -331,32 +405,32 @@ fn lir_mat_mul_unary(
)
.map(Some);
}
let a_dt = input_facts[0].datum_type;
let b_dt = input_facts[1].datum_type;
let dt = op.operating_dt;
let mmm = tract_linalg::ops()
.mmm(a_dt, b_dt, dt, m.to_usize().ok(), k.to_usize().ok(), n.to_usize().ok())
.unwrap();
let name = &node.name;

let (mmm, packing) = if let Some(pair) = select_kernel_and_packing(model, node, m, n)? {
pair
} else {
let a_dt = input_facts[0].datum_type;
let b_dt = input_facts[1].datum_type;
let dt = op.operating_dt;
let mmm = tract_linalg::ops()
.mmm(a_dt, b_dt, dt, m.to_usize().ok(), k.to_usize().ok(), n.to_usize().ok())
.unwrap();
let packing = mmm
.packings()
.iter()
.position(|p| {
p.0.can_prepare_types().contains(&a_dt.unquantized())
&& p.1.can_prepare_types().contains(&b_dt.unquantized())
})
.with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
(mmm, packing)
};

let mut patch = TypedModelPatch::new("Einsum to LirMatMulUnary");
let a = patch.tap_model(model, node.inputs[0])?;
let b = patch.tap_model(model, node.inputs[1])?;
let packing = mmm
.packings()
.iter()
.position(|p| {
p.0.can_prepare_types().contains(&a_dt.unquantized()) && p.1.can_prepare_types().contains(&b_dt.unquantized())
})
.with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
let packers = mmm.packings()[packing];
let a_pack =
packers.0.downcast_ref::<Packer>().context("Expects regular packed format for A")?.clone();
let b_pack =
packers.1.downcast_ref::<Packer>().context("Expects regular packed format for B")?.clone();
let pack_a = MatMatMulPack { packer: a_pack, k_axis: a_k, mn_axis: a_m };
let pack_b = MatMatMulPack { packer: b_pack, k_axis: b_k, mn_axis: b_n };
let pa = patch.wire_node(format!("{name}.pack_a"), pack_a, &[a])?[0];
let pb = patch.wire_node(format!("{name}.pack_b"), pack_b, &[b])?[0];

let pa = wire_packing(model, node, 0, &mut patch, packers.0, a_k, a_m)?;
let pb = wire_packing(model, node, 1, &mut patch, packers.1, b_k, b_n)?;

let mut c_to_a_axis_mapping = tvec!();
let mut c_to_b_axis_mapping = tvec!();
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/einsum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod as_blas;
use super::array::TypedConcat;
use super::math::add;
mod as_matmul;
mod codegen;
pub mod codegen;

#[cfg(test)]
mod proptest;
Expand Down
1 change: 1 addition & 0 deletions core/src/ops/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod de_block_quant;
pub mod lir_unary;
pub mod mir_quant;
pub mod pack;
Expand Down
Loading

0 comments on commit adf8f3b

Please sign in to comment.