Skip to content

Commit

Permalink
wip breaking up pack
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 20, 2024
1 parent 464ea78 commit 79239f8
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 88 deletions.
55 changes: 35 additions & 20 deletions core/src/ops/einsum/kernel_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ use tract_linalg::mmm::MatMatMul;

use crate::internal::*;
use crate::ops::matmul::de_block_quant::BlockQuantValue;
use crate::ops::matmul::pack::Packer;

/// Kernel/packing selection result for an einsum-as-matmul node.
///
/// `static_packing_for_a` / `static_packing_for_b` describe how the A and B
/// operands are packed once, up front (see the `pack_a_0` / `pack_b_0` wiring
/// in `optimized_mat_mul`). `scenarios` lists, per runtime scenario, the
/// chosen kernel, the index of the packing in that kernel's `packings()`
/// list, and the extra per-scenario packers for A and B (currently always
/// `Packer::Identity` in the paths visible here).
#[derive(Debug)]
pub struct Strategy {
    // Upfront (scenario-independent) packing applied to operand A.
    pub static_packing_for_a: Packer,
    // Upfront (scenario-independent) packing applied to operand B.
    pub static_packing_for_b: Packer,
    // One entry per scenario: (kernel, packing index, A re-packer, B re-packer).
    pub scenarios: Vec<(Box<dyn MatMatMul>, usize, Packer, Packer)>,
}

pub fn select_kernel_and_packing(
model: &TypedModel,
Expand All @@ -13,7 +21,7 @@ pub fn select_kernel_and_packing(
k: &TDim,
n: &TDim,
operating_dt: DatumType,
) -> TractResult<Vec<(Box<dyn MatMatMul>, usize)>> {
) -> TractResult<Strategy> {
let input_facts = model.node_input_facts(node.id)?;
let a_dt = input_facts[0].datum_type;
let b_dt = input_facts[1].datum_type;
Expand All @@ -31,51 +39,58 @@ pub fn select_kernel_and_packing(
.map(|(s, _)| n.eval_with_scenario(&s))
.collect_vec();
let _need_matvec = all_n_values.iter().any(|n| n.is_one());
// println!("{m} {n} {all_n_values:?} {need_matvec:?}");
let mut options: Vec<(&Box<dyn MatMatMul>, usize)> = vec![];
println!("{m} {n} {all_n_values:?} {_need_matvec:?}");
let mut options: Vec<(&Box<dyn MatMatMul>, usize, &PackedBlockQuantFormat, &PackedFormat)> =
vec![];
for imp in tract_linalg::ops().mmm_impls() {
for (packing, (pack_a, pack_b)) in imp.packings().iter().enumerate() {
// println!("{imp:?} {packing} {pack_a} {pack_b}");
if let (Some(input), Some(b)) = (
println!("{imp:?} {packing} {pack_a} {pack_b}");
if let (Some(pa), Some(pb)) = (
pack_a.downcast_ref::<PackedBlockQuantFormat>(),
pack_b.downcast_ref::<PackedFormat>(),
) {
if input.bq.same_as(&*bqv.fact.format) && b.dt == b_dt {
options.push((imp, packing));
if pa.bq.same_as(&*bqv.fact.format) && pb.dt == b_dt {
options.push((imp, packing, pa, pb));
}
}
}
}
if options.len() > 0 {
let pair = if let (Some(m), Some(n)) = (m.as_i64(), n.as_i64()) {
let (mmm, packing, pa, pb) = if let (Some(m), Some(n)) = (m.as_i64(), n.as_i64()) {
options
.iter()
.into_iter()
.min_by_key(|a| {
((m as usize).divceil(a.0.mr()) * (n as usize).divceil(a.0.nr()))
* (a.0.mr() * a.0.nr() + 100)
})
.unwrap()
} else {
options.iter().max_by_key(|a| a.0.mr() * a.0.nr()).unwrap()
options.into_iter().max_by_key(|a| a.0.mr() * a.0.nr()).unwrap()
};
return Ok(vec!((pair.0.clone(), pair.1)));
return Ok(Strategy {
static_packing_for_a: Packer::PackBlockQuant(pa.clone()),
static_packing_for_b: Packer::Regular(pb.clone()),
scenarios: vec![(mmm.clone(), packing, Packer::Identity, Packer::Identity)],
});
}
}

// "simple" kernel selection
let mmm = tract_linalg::ops()
.mmm(operating_dt, m.to_usize().ok(), k.to_usize().ok(), n.to_usize().ok())
.unwrap();
let packing = mmm
let (packing, pa, pb) = mmm
.packings()
.iter()
.position(|p| {
p.0.downcast_ref::<PackedFormat>()
.is_some_and(|pf| pf.dt.unquantized() == a_dt.unquantized())
&& p.1
.downcast_ref::<PackedFormat>()
.is_some_and(|pf| pf.dt.unquantized() == b_dt.unquantized())
.into_iter()
.enumerate()
.filter_map(|(ix, p)| {
Some((ix, p.0.downcast_ref::<PackedFormat>()?, p.1.downcast_ref::<PackedFormat>()?))
})
.find(|(_ix, pa, pb)| pa.dt == a_dt.unquantized() && pb.dt == b_dt.unquantized())
.with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
Ok(vec!((mmm, packing)))
Ok(Strategy {
static_packing_for_a: Packer::Regular(pa.clone()),
static_packing_for_b: Packer::Regular(pb.clone()),
scenarios: vec![(mmm, packing, Packer::Identity, Packer::Identity)],
})
}
141 changes: 81 additions & 60 deletions core/src/ops/einsum/optimize.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use kernel_selection::select_kernel_and_packing;
use tract_linalg::frame::block_quant::PackedBlockQuantFormat;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::{MMMInputFormat, MMMInputValue};

use super::*;
use crate::ops::cast::cast;
use crate::ops::math::add;
use crate::ops::matmul::de_block_quant::BlockQuantValue;
use crate::ops::matmul::optimized::{
AddMatMulGeometry, MapOutputAxisToInput, OptMatMul, ProtoFusedSpec,
};
use crate::ops::matmul::pack::{MatMatMulPack, Packer};
use crate::ops::matmul::pack::MatMatMulPack;
use crate::ops::matmul::quant::{
combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
};
Expand Down Expand Up @@ -298,49 +294,54 @@ fn dequant(
Ok(Some(patch))
}

fn wire_packing(
model: &TypedModel,
node: &TypedNode,
input: usize,
patch: &mut TypedModelPatch,
packer: &[&dyn MMMInputFormat],
k_axis: usize,
mn_axis: usize,
) -> TractResult<OutletId> {
let name = format!("{}.pack_{}", node.name, ['a', 'b'][input]);
let a_fact = model.outlet_fact(node.inputs[0])?;
if let Some(packers) = packer
.iter()
.map(|p| p.downcast_ref::<PackedFormat>().cloned().map(|f| Packer::Regular(f)))
.collect::<Option<Vec<_>>>()
{
let wire = patch.tap_model(model, node.inputs[input])?;
let pack_a = MatMatMulPack { packers, k_axis, mn_axis };
Ok(patch.wire_node(&name, pack_a, &[wire])?[0])
/*
} else if let (Some(bqf), Some(pbqf)) = (
a_fact.opaque_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>()),
packer.downcast_ref::<PackedBlockQuantFormat>(),
) {
ensure!(k_axis == 1);
ensure!(mn_axis == 0);
ensure!(pbqf.bq.same_as(&*bqf.format));
let Some(weights) = &a_fact.konst else {
bail!("Block quant packing with non-const inputs")
};
ensure!(weights.datum_type() == Opaque::datum_type());
let Some(weights) = weights.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>() else {
bail!("Expected a BlockQuantValue, found {weights:?}")
};
let k = bqf.shape[k_axis].to_usize()?;
let packed = pbqf.pack(&weights.value, k)?;
let mmm_input: Box<dyn MMMInputValue> = Box::new(packed);
patch.add_const(name, tensor0(Opaque::from(mmm_input)))
*/
} else {
bail!("Unexpected packing format: {:?}", packer);
}
/*
fn wire_packing(
model: &TypedModel,
node: &TypedNode,
input: usize,
patch: &mut TypedModelPatch,
packers: Vec<Packer>,
k_axis: usize,
mn_axis: usize,
) -> TractResult<OutletId> {
let name = format!("{}.pack_{}", node.name, ['a', 'b'][input]);
// let a_fact = model.outlet_fact(node.inputs[0])?;
let wire = patch.tap_model(model, node.inputs[input])?;
let pack_a = MatMatMulPack { packers, k_axis, mn_axis };
Ok(patch.wire_node(&name, pack_a, &[wire])?[0])
/*
if let Some(packers) = packer
.iter()
.map(|p| p.downcast_ref::<PackedFormat>().cloned().map(|f| Packer::Regular(f)))
.collect::<Option<Vec<_>>>()
{
let wire = patch.tap_model(model, node.inputs[input])?;
let pack_a = MatMatMulPack { packers, k_axis, mn_axis };
Ok(patch.wire_node(&name, pack_a, &[wire])?[0])
} else if let (Some(bqf), Some(pbqf)) = (
a_fact.opaque_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>()),
packer.downcast_ref::<PackedBlockQuantFormat>(),
) {
ensure!(k_axis == 1);
ensure!(mn_axis == 0);
ensure!(pbqf.bq.same_as(&*bqf.format));
let Some(weights) = &a_fact.konst else {
bail!("Block quant packing with non-const inputs")
};
ensure!(weights.datum_type() == Opaque::datum_type());
let Some(weights) = weights.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>() else {
bail!("Expected a BlockQuantValue, found {weights:?}")
};
let k = bqf.shape[k_axis].to_usize()?;
let packed = pbqf.pack(&weights.value, k)?;
let mmm_input: Box<dyn MMMInputValue> = Box::new(packed);
patch.add_const(name, tensor0(Opaque::from(mmm_input)))
} else {
bail!("Unexpected packing format: {:?}", packer);
}
*/
}
*/

fn optimized_mat_mul(
op: &EinSum,
Expand Down Expand Up @@ -383,17 +384,38 @@ fn optimized_mat_mul(
.map(Some);
}

let selection = select_kernel_and_packing(model, node, m, k, n, op.operating_dt)?;
let (packers_a, packers_b): (Vec<_>, Vec<_>) =
selection.iter().map(|(mm, pack)| mm.packings()[*pack]).unzip();

let strategy = select_kernel_and_packing(model, node, m, k, n, op.operating_dt)?;
let name = &node.name;
let mut patch = TypedModelPatch::new("Einsum to OptMatMul");

let pa = wire_packing(model, node, 0, &mut patch, &packers_a, a_k, a_m)
.with_context(|| format!("Wiring packing {:?} for a: {:?}", &packers_a, input_facts[0]))?;
let pb = wire_packing(model, node, 1, &mut patch, &packers_b, b_k, b_n)
.with_context(|| format!("Wiring packing {:?} for b: {:?}", &packers_b, input_facts[1]))?;

let taps = patch.taps(model, &node.inputs)?;
let pack_a_0 = patch.wire_node(
format!("{name}.pack_a_0"),
MatMatMulPack { k_axis: a_k, mn_axis: a_m, packers: vec![strategy.static_packing_for_a] },
&[taps[0]],
)?;
let pack_a = patch.wire_node(
format!("{name}.pack_a"),
MatMatMulPack {
k_axis: a_k,
mn_axis: a_m,
packers: strategy.scenarios.iter().map(|s| s.2.clone()).collect(),
},
&pack_a_0,
)?;
let pack_b_0 = patch.wire_node(
format!("{name}.pack_b_0"),
MatMatMulPack { k_axis: b_k, mn_axis: b_n, packers: vec![strategy.static_packing_for_b] },
&[taps[1]],
)?;
let pack_b = patch.wire_node(
format!("{name}.pack_b"),
MatMatMulPack {
k_axis: b_k,
mn_axis: b_n,
packers: strategy.scenarios.iter().map(|s| s.3.clone()).collect(),
},
&pack_b_0,
)?;
let mut c_to_a_axis_mapping = tvec!();
let mut c_to_b_axis_mapping = tvec!();
for axis in op.axes.iter_all_axes().filter(|&axis| ![m_axis, k_axis, n_axis].contains(&axis)) {
Expand All @@ -412,13 +434,12 @@ fn optimized_mat_mul(
}

let c_fact = op.output_facts(&input_facts)?.remove(0);
let name = &node.name;
let geo = AddMatMulGeometry {
k: k.clone(),
c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
};
let (mmm, packing) = selection[0].clone();
let (mmm, packing, _, _) = strategy.scenarios[0].clone();
let output = unsafe { mmm.c_view(c_m, c_n) };
let opt = OptMatMul::new(
mmm,
Expand All @@ -428,7 +449,7 @@ fn optimized_mat_mul(
vec![ProtoFusedSpec::AddMatMul { geo, a: 0, b: 1, packing }, ProtoFusedSpec::Store(output)],
)
.context("Creating OptMatMul")?;
let output = patch.wire_node(name, opt, &[pa, pb])?[0];
let output = patch.wire_node(name, opt, &[pack_a[0], pack_b[0]])?[0];
patch.shunt_outside(model, node.id.into(), output)?;
Ok(Some(patch))
}
34 changes: 27 additions & 7 deletions core/src/ops/matmul/pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ use crate::axes::Axis;
use crate::internal::*;
use ndarray::*;
use tract_data::TooEarly;
use tract_itertools::Itertools;
use tract_linalg::frame::block_quant::PackedBlockQuantFormat;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::MMMInputValue;

/// Describes how a matmul operand tensor is packed for a kernel.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Packer {
    /// Dense packing through a linalg `PackedFormat`
    /// (the only variant `Packer::pack` currently handles).
    Regular(PackedFormat),
    /// Packing of block-quantized data via a `PackedBlockQuantFormat`.
    PackBlockQuant(PackedBlockQuantFormat),
    /// NOTE(review): never constructed in the code visible here; presumably a
    /// placeholder for panel-extraction-based packing — confirm before relying on it.
    PanelExtractionWrapper,
    /// Pass-through: input is forwarded unchanged (`do_eval` returns the input
    /// as-is, and `declutter` shunts an all-Identity pack op entirely).
    Identity,
}
Expand All @@ -21,7 +24,7 @@ impl Packer {
) -> TractResult<Box<dyn MMMInputValue>> {
match self {
Packer::Regular(format) => format.pack_tensor_view(view, k_axis, mn_axis),
_ => todo!(),
_ => todo!("Missing packer {self:?}"),
}
}
}
Expand Down Expand Up @@ -54,9 +57,9 @@ impl EvalOp for MatMatMulPack {
fn eval_with_session(
&self,
session: &SessionState,
inputs: TVec<TValue>,
mut inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
self.do_eval(session, &inputs[0])
self.do_eval(session, inputs.remove(0))
}
}

Expand All @@ -82,11 +85,22 @@ impl TypedOp for MatMatMulPack {
AxesMapping::new(1, 1, axes)
}

/// Declutter pass: a pack op whose packers are all `Packer::Identity` is a
/// pure pass-through, so it can be removed from the model outright.
fn declutter(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let has_real_packer = self.packers.iter().any(|p| *p != Packer::Identity);
    if has_real_packer {
        // At least one packer does actual work: keep the node.
        Ok(None)
    } else {
        // Every packer is the identity: bypass this op entirely.
        TypedModelPatch::shunt_one_op(model, node)
    }
}

as_op!();
}

impl MatMatMulPack {
fn do_eval(&self, session: &SessionState, input: &Tensor) -> TractResult<TVec<TValue>> {
fn do_eval(&self, session: &SessionState, input: TValue) -> TractResult<TVec<TValue>> {
unsafe {
let packer = if self.packers.len() == 1 {
&self.packers[0]
Expand All @@ -95,6 +109,9 @@ impl MatMatMulPack {
} else {
bail!(TooEarly::Other("Undetermined scenario".into()))
};
if *packer == Packer::Identity {
return Ok(tvec!(input))
}
let output_shape: TVec<usize> = self.output_shape(input.shape());
let stores = if output_shape.iter().all(|d| *d == 1) {
tensor0::<Opaque>(
Expand All @@ -121,7 +138,7 @@ impl MatMatMulPack {
pack_coords.remove(self.k_axis.min(self.mn_axis));
stores_view[&*pack_coords] = packer
.pack_tensor_view(
&TensorView::from_bytes(input, offset, input.shape(), input.strides()),
&TensorView::from_bytes(&input, offset, input.shape(), input.strides()),
self.k_axis,
self.mn_axis,
)?
Expand All @@ -134,9 +151,12 @@ impl MatMatMulPack {
}

pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
assert!(self.packers.iter().map(|p| matches!(p, Packer::Regular(_))).all_equal());
let mut packed_shape: TVec<D> = input.into();
packed_shape.remove(self.mn_axis.max(self.k_axis));
packed_shape.remove(self.mn_axis.min(self.k_axis));
if matches!(self.packers[0], Packer::Regular(_)) {
packed_shape.remove(self.mn_axis.max(self.k_axis));
packed_shape.remove(self.mn_axis.min(self.k_axis));
}
packed_shape
}
}
Loading

0 comments on commit 79239f8

Please sign in to comment.