diff --git a/core/src/ops/cnn/conv/conv.rs b/core/src/ops/cnn/conv/conv.rs
index 936a06d285..c729418461 100644
--- a/core/src/ops/cnn/conv/conv.rs
+++ b/core/src/ops/cnn/conv/conv.rs
@@ -28,7 +28,7 @@ use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
 use crate::ops::matmul::lir_unary::{LirMatMulUnary, ProtoFusedSpec};
 use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};
 
-use tract_linalg::frame::Packer;
+use tract_linalg::frame::PackedFormat;
 use tract_linalg::mmm::MatMatMul;
 
 #[derive(Debug, Clone, new, Hash)]
@@ -74,7 +74,7 @@ impl Conv {
         &self,
         model: &mut TypedModel,
         name: &str,
-        packer: Packer,
+        packer: PackedFormat,
         kernel: OutletId,
     ) -> TractResult<OutletId> {
         Ok(model.wire_node(
@@ -177,7 +177,7 @@ impl Conv {
             &sum_ker_n_g_c,
         )?;
-        ensure!(mmm.packings()[packing].1.downcast_ref::<Packer>().is_some());
+        ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
         let mut sum_x = model.wire_node(
             format!("{name}.sum_x"),
             super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
@@ -405,7 +405,7 @@ impl Conv {
         let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
         let packer = mmm.packings()[packing]
             .1
-            .downcast_ref::<Packer>()
+            .downcast_ref::<PackedFormat>()
             .with_context(|| {
                 format_err!(
                     "Im2Col expects regular packed format, got {:?}",
@@ -486,7 +486,7 @@ impl Conv {
         ensure!(model.outlet_fact(bias)?.datum_type == mmm.internal_type());
         let a_pack = mmm.packings()[packing]
             .0
-            .downcast_ref::<Packer>()
+            .downcast_ref::<PackedFormat>()
             .context("Conv expects weights in regular packed format")?
             .clone();
         let packed_ker = self
diff --git a/core/src/ops/cnn/conv/im2col.rs b/core/src/ops/cnn/conv/im2col.rs
index 1fa0cd7478..be7ca17030 100644
--- a/core/src/ops/cnn/conv/im2col.rs
+++ b/core/src/ops/cnn/conv/im2col.rs
@@ -1,5 +1,5 @@
-use tract_linalg::frame::{MatMatMul, Packer, PackingWriter};
-use tract_linalg::mmm::{EagerPackedInput, MMMInput};
+use tract_linalg::frame::{MatMatMul, PackedFormat, PackingWriter};
+use tract_linalg::mmm::{EagerPackedInput, MMMInputValue};
 
 use crate::internal::*;
 use ndarray::prelude::*;
@@ -21,7 +21,7 @@ struct SymbolicGeometry {
     group: usize,
     pool_spec: PoolSpec,
     pool_geometry: PoolGeometry,
-    b_pack: Packer,
+    b_pack: PackedFormat,
     k: usize,
 }
 
@@ -30,7 +30,7 @@ struct ConcreteGeometry {
     pool: ConcretePoolGeometry,
     pub n: usize,
     k: usize,
-    pub b_pack: Packer,
+    pub b_pack: PackedFormat,
     pub ci_per_group: usize,
     patcher: Patcher,
     input_shape_with_n: DataShape,
@@ -38,7 +38,7 @@ struct ConcreteGeometry {
 }
 
 impl GeometryBound {
-    pub fn b_pack(&self) -> &Packer {
+    pub fn b_pack(&self) -> &PackedFormat {
         match self {
             GeometryBound::Symbolic(s) => &s.b_pack,
             GeometryBound::Concrete(s) => &s.b_pack,
@@ -104,7 +104,7 @@ impl Im2Col {
     ) -> TractResult<Im2Col> {
         let b_pack = mmm.packings()[0]
             .1
-            .downcast_ref::<Packer>()
+            .downcast_ref::<PackedFormat>()
             .context("Im2Col expects regular packed format")?
             .clone();
@@ -178,12 +178,12 @@ impl EvalOp for Im2Col {
                     g,
                     pad_value
                 ))?;
-                let input: Box<dyn MMMInput> = Box::new(EagerPackedInput {
-                    packed: data,
+                let input: Box<dyn MMMInputValue> = Box::new(EagerPackedInput {
+                    format: Box::new(geometry.b_pack.clone()),
+                    packed: data.into_blob()?,
                     panel_bytes,
                     k: geometry.k,
                     mn: geometry.n,
-                    r: geometry.b_pack.r,
                 });
                 output_view[[i, g]] = input.into();
             }
diff --git a/core/src/ops/cnn/conv/lazy_im2col.rs b/core/src/ops/cnn/conv/lazy_im2col.rs
index e81ae13c8c..8ae4dd2537 100644
--- a/core/src/ops/cnn/conv/lazy_im2col.rs
+++ b/core/src/ops/cnn/conv/lazy_im2col.rs
@@ -1,12 +1,12 @@
 use crate::internal::*;
 use std::fmt::{Debug, Display};
 use std::ops::Range;
-use tract_linalg::frame::{Packer, PackingWriter};
-use tract_linalg::mmm::MMMInput;
+use tract_linalg::frame::{PackedFormat, PackingWriter};
+use tract_linalg::mmm::MMMInputValue;
 
 #[derive(Clone, Debug, Hash, PartialEq)]
 pub struct LazyIm2colParams {
-    pub packer: Packer,
+    pub packer: PackedFormat,
     pub n_byte_offsets: Vec<isize>,
     pub k_byte_offsets: Vec<isize>,
 }
@@ -32,7 +32,7 @@ impl EvalOp for LazyIm2Col {
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
         let tensor = args_1!(inputs);
-        let input: Box<dyn MMMInput> =
+        let input: Box<dyn MMMInputValue> =
             Box::new(LazyIm2colInput { tensor, im2col: self.params.clone() });
         let input = Opaque(Arc::new(input));
         Ok(tvec!(tensor2(&[[input]]).into_tvalue()))
@@ -289,14 +289,14 @@ impl LazyIm2colInput {
     }
 }
 
-impl MMMInput for LazyIm2colInput {
+impl MMMInputValue for LazyIm2colInput {
     fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
         let k = self.im2col.k_byte_offsets.len();
         Some(self.im2col.packer.single_panel_layout(k, self.tensor.datum_type().size_of()))
     }
-    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> *const u8 {
-        dispatch_copy!(Self::do_panel(self.tensor.datum_type())(self, i, buffer))
+    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
+        Ok(dispatch_copy!(Self::do_panel(self.tensor.datum_type())(self, i, buffer)))
     }
 
     fn k(&self) -> usize {
diff --git a/core/src/ops/cnn/conv/q_sum_b.rs b/core/src/ops/cnn/conv/q_sum_b.rs
index 971e413359..50836a7f0b 100644
--- a/core/src/ops/cnn/conv/q_sum_b.rs
+++ b/core/src/ops/cnn/conv/q_sum_b.rs
@@ -1,5 +1,5 @@
 use crate::internal::*;
-use tract_linalg::mmm::MMMInput;
+use tract_linalg::mmm::MMMInputValue;
 use tract_ndarray::prelude::*;
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
@@ -61,7 +61,7 @@ impl QSumB {
             let mut output_view = output_view.index_axis_mut(Axis(0), g);
             let output_slice = output_view.as_slice_mut().unwrap();
             let payload = payloads[[b, g]]
-                .downcast_ref::<Box<dyn MMMInput>>()
+                .downcast_ref::<Box<dyn MMMInputValue>>()
                 .context("Expected MMMInputs")?;
             match self.dt.unquantized() {
                 DatumType::I8 => self.eval_t::<i8>(&**payload, output_slice)?,
@@ -75,13 +75,13 @@ impl QSumB {
 
     fn eval_t<T: Datum + AsPrimitive<i32>>(
         &self,
-        input: &dyn MMMInput,
+        input: &dyn MMMInputValue,
         output: &mut [i32],
     ) -> TractResult<()> {
         let (r, k, n) = (input.r(), input.k(), input.mn());
         let panels = n.divceil(r);
         for ipanel in 0..panels {
-            let panel = input.panel_bytes(ipanel, None);
+            let panel = input.panel_bytes(ipanel, None)?;
             let panel: &[T] = unsafe { std::slice::from_raw_parts(panel as *const T, r * k) };
             let mut vec = vec![0i32; r];
             for ik in 0..k {
diff --git a/core/src/ops/einsum/codegen.rs b/core/src/ops/einsum/codegen.rs
index 23170b0b52..f6e1bf5a74 100644
--- a/core/src/ops/einsum/codegen.rs
+++ b/core/src/ops/einsum/codegen.rs
@@ -1,8 +1,11 @@
-use tract_linalg::frame::Packer;
+use tract_linalg::frame::block_quant::PackedBlockQuantFormat;
+use tract_linalg::frame::PackedFormat;
+use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, MatMatMul};
 
 use super::*;
 use crate::ops::cast::cast;
 use crate::ops::math::add;
+use crate::ops::matmul::de_block_quant::DeBlockQuant;
 use crate::ops::matmul::lir_unary::{
     AddMatMulGeometry, LirMatMulUnary, MapOutputAxisToInput, ProtoFusedSpec,
 };
@@ -41,7 +44,7 @@ pub(crate) fn codegen(
         }
     }
 
-pub(super) fn ensure_mkn_axes<'a>(
+pub(crate) fn ensure_mkn_axes<'a>(
     op: &'a EinSum,
     model: &TypedModel,
     node: &TypedNode,
@@ -292,6 +295,77 @@ fn dequant(
     Ok(Some(patch))
 }
 
+fn select_kernel_and_packing(
+    model: &TypedModel,
+    node: &TypedNode,
+    m: &TDim,
+    n: &TDim,
+) -> TractResult<Option<(Box<dyn MatMatMul>, usize)>> {
+    if let Some(dbq) = model.node(node.inputs[0].node).op_as::<DeBlockQuant>() {
+        let mut options: Vec<(&Box<dyn MatMatMul>, usize)> = vec![];
+        for imp in tract_linalg::ops().mmm_impls() {
+            for (packing, (pack_a, _)) in imp.packings().iter().enumerate() {
+                if let Some(input) = pack_a.downcast_ref::<PackedBlockQuantFormat>() {
+                    if input.bq.same_as(&*dbq.bq) {
+                        options.push((imp, packing));
+                    }
+                }
+            }
+        }
+        if options.len() > 0 {
+            let pair = if let (Some(m), Some(n)) = (m.as_i64(), n.as_i64()) {
+                options
+                    .iter()
+                    .min_by_key(|a| {
+                        ((m as usize).divceil(a.0.mr()) * (n as usize).divceil(a.0.nr()))
+                            * a.0.mr()
+                            * a.0.nr()
+                    })
+                    .unwrap()
+            } else {
+                options.iter().max_by_key(|a| a.0.mr() * a.0.nr()).unwrap()
+            };
+            return Ok(Some((pair.0.clone(), pair.1)));
+        }
+    }
+    Ok(None)
+}
+
+fn wire_packing(
+    model: &TypedModel,
+    node: &TypedNode,
+    input: usize,
+    patch: &mut TypedModelPatch,
+    packer: &dyn MMMInputFormat,
+    k_axis: usize,
+    mn_axis: usize,
+) -> TractResult<OutletId> {
+    let name = format!("{}.pack_{}", node.name, ['a', 'b'][input]);
+    if let Some(packed_format) = packer.downcast_ref::<PackedFormat>().cloned() {
+        let wire = patch.tap_model(model, node.inputs[input])?;
+        let pack_a = MatMatMulPack { packer: packed_format, k_axis, mn_axis };
+        Ok(patch.wire_node(&name, pack_a, &[wire])?[0])
+    } else if let Some(pbqf) = packer.downcast_ref::<PackedBlockQuantFormat>() {
+        ensure!(k_axis == 1);
+        ensure!(mn_axis == 0);
+        let prec = model.node(node.inputs[input].node);
+        ensure!(prec.outputs[0].successors.len() == 1);
+        let Some(deq) = prec.op_as::<DeBlockQuant>() else {
+            bail!("Precursor should be dequantizer")
+        };
+        ensure!(deq.bq.same_as(&*pbqf.bq));
+        let Some(weights) = &model.outlet_fact(prec.inputs[0])?.konst else {
+            bail!("Block quant packing with non-const inputs")
+        };
+        let k = deq.fact.shape[k_axis].to_usize()?;
+        let packed = deq.bq.pack(weights.to_scalar::<Blob>()?, k, pbqf.r)?;
+        let mmm_input: Box<dyn MMMInputValue> = Box::new(packed);
+        patch.add_const(name, tensor0(Opaque::from(mmm_input)))
+    } else {
+        bail!("Unexpected packing format: {:?}", packer);
+    }
+}
+
 fn lir_mat_mul_unary(
     op: &EinSum,
     model: &TypedModel,
@@ -331,32 +405,32 @@ fn lir_mat_mul_unary(
         )
         .map(Some);
     }
-    let a_dt = input_facts[0].datum_type;
-    let b_dt = input_facts[1].datum_type;
-    let dt = op.operating_dt;
-    let mmm = tract_linalg::ops()
-        .mmm(a_dt, b_dt, dt, m.to_usize().ok(), k.to_usize().ok(), n.to_usize().ok())
-        .unwrap();
-    let name = &node.name;
+
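+    // Kernel and packing selection: when A comes from a DeBlockQuant, prefer a kernel
+    // exposing a native packing for that block-quant format (select_kernel_and_packing
+    // above picks the candidate minimizing padded tile work ceil(m/mr)*ceil(n/nr)*mr*nr
+    // when m and n are known). Otherwise fall back to datum-type driven selection.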
+    let (mmm, packing) = if let Some(pair) = select_kernel_and_packing(model, node, m, n)? {
+        pair
+    } else {
+        let a_dt = input_facts[0].datum_type;
+        let b_dt = input_facts[1].datum_type;
+        let dt = op.operating_dt;
+        let mmm = tract_linalg::ops()
+            .mmm(a_dt, b_dt, dt, m.to_usize().ok(), k.to_usize().ok(), n.to_usize().ok())
+            .unwrap();
+        let packing = mmm
+            .packings()
+            .iter()
+            .position(|p| {
+                p.0.can_prepare_types().contains(&a_dt.unquantized())
+                    && p.1.can_prepare_types().contains(&b_dt.unquantized())
+            })
+            .with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
+        (mmm, packing)
+    };
+
     let mut patch = TypedModelPatch::new("Einsum to LirMatMulUnary");
-    let a = patch.tap_model(model, node.inputs[0])?;
-    let b = patch.tap_model(model, node.inputs[1])?;
-    let packing = mmm
-        .packings()
-        .iter()
-        .position(|p| {
-            p.0.can_prepare_types().contains(&a_dt.unquantized()) && p.1.can_prepare_types().contains(&b_dt.unquantized())
-        })
-        .with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
     let packers = mmm.packings()[packing];
-    let a_pack =
-        packers.0.downcast_ref::<Packer>().context("Expects regular packed format for A")?.clone();
-    let b_pack =
-        packers.1.downcast_ref::<Packer>().context("Expects regular packed format for B")?.clone();
-    let pack_a = MatMatMulPack { packer: a_pack, k_axis: a_k, mn_axis: a_m };
-    let pack_b = MatMatMulPack { packer: b_pack, k_axis: b_k, mn_axis: b_n };
-    let pa = patch.wire_node(format!("{name}.pack_a"), pack_a, &[a])?[0];
-    let pb = patch.wire_node(format!("{name}.pack_b"), pack_b, &[b])?[0];
+
+    let pa = wire_packing(model, node, 0, &mut patch, packers.0, a_k, a_m)?;
+    let pb = wire_packing(model, node, 1, &mut patch, packers.1, b_k, b_n)?;
 
     let mut c_to_a_axis_mapping = tvec!();
     let mut c_to_b_axis_mapping = tvec!();
diff --git a/core/src/ops/einsum/mod.rs b/core/src/ops/einsum/mod.rs
index 0ae478d8fc..31543445ec 100644
--- a/core/src/ops/einsum/mod.rs
+++ b/core/src/ops/einsum/mod.rs
@@ -11,7 +11,7 @@ pub mod as_blas;
 use super::array::TypedConcat;
 use super::math::add;
 mod as_matmul;
-mod codegen;
+pub mod codegen;
 
 #[cfg(test)]
 mod proptest;
diff --git a/core/src/ops/matmul.rs b/core/src/ops/matmul.rs
index a672694bba..19ddebe013 100644
--- a/core/src/ops/matmul.rs
+++ b/core/src/ops/matmul.rs
@@ -1,3 +1,4 @@
+pub mod de_block_quant;
 pub mod lir_unary;
 pub mod mir_quant;
 pub mod pack;
diff --git a/core/src/ops/matmul/de_block_quant.rs b/core/src/ops/matmul/de_block_quant.rs
new file mode 100644
index 0000000000..f2e7157ee5
--- /dev/null
+++ b/core/src/ops/matmul/de_block_quant.rs
@@ -0,0 +1,136 @@
+use tract_linalg::frame::block_quant::{BlockQuant, Q4_0};
+
+use crate::internal::*;
+use crate::ops::einsum::codegen::{ensure_mkn_axes, AxesOrPatch};
+use crate::ops::einsum::EinSum;
+use crate::transform::ModelTransform;
+
+#[derive(Debug)]
+pub struct BlockQuantTransform;
+
+impl ModelTransform for BlockQuantTransform {
+    fn name(&self) -> Cow<str> {
+        "BlockQuantTransform".into()
+    }
+
+    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
+        Rewriter::<()>::default()
+            .with_rule_for("block_quant_einsum_weights", block_quant_einsum_weights)
+            .rewrite(&(), model)
+    }
+}
+
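+// Rewrite rule: normalize the EinSum so the constant weight operand sits in slot 0
+// with axes in (m, k) order, then replace the f32 weights by their Q4_0-quantized
+// blob followed by a DeBlockQuant op. Downstream codegen and pack fusion can then
+// consume the quantized form directly instead of dequantizing eagerly.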
+fn block_quant_einsum_weights(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    prefix: &str,
+    op: &EinSum,
+) -> TractResult<Option<TypedModelPatch>> {
+    let &[a, b] = &*model.node_input_facts(node.id)? else { return Ok(None) };
+    if b.konst.is_some() {
+        let mut new_axes = op.axes.clone().with_extra_input(2)?;
+        for (ix, axis) in op.axes.axes(InOut::In(0)).enumerate() {
+            new_axes = new_axes.with_extra_axis_occurency(axis.repr, InOut::In(2), ix)?;
+        }
+        new_axes = new_axes.remove_slot(InOut::In(0))?;
+        return Ok(Some(TypedModelPatch::replace_single_op(
+            model,
+            node,
+            &[node.inputs[1], node.inputs[0]],
+            EinSum { axes: new_axes, ..op.clone() },
+        )?));
+    }
+    if a.konst.is_none() || a.rank() != 2 {
+        return Ok(None);
+    }
+    let AxesOrPatch::Axes(m, k, _n) = ensure_mkn_axes(op, model, node)? else { return Ok(None) };
+    if m.inputs[0][0] == 1 && k.inputs[0][0] == 0 {
+        let a: &Tensor = a.konst.as_ref().unwrap();
+        let mut patch = TypedModelPatch::default();
+        let konst =
+            patch.add_const(&model.node(node.inputs[0].node).name, a.clone().move_axis(1, 0)?)?;
+        let axes = op
+            .axes
+            .clone()
+            .with_extra_axis_occurency(k, InOut::In(0), 2)?
+            .remove_axis_occurency(InOut::In(0), 0)?;
+        let tap = patch.tap_model(model, node.inputs[1])?;
+        let output = patch.wire_node(prefix, EinSum { axes, ..op.clone() }, &[konst, tap])?;
+        patch.shunt_outside(model, node.id.into(), output[0])?;
+        return Ok(Some(patch));
+    }
+    let mut patch = TypedModelPatch::default();
+    let weights = Q4_0.quant_f32(a.konst.as_ref().unwrap().cast_to::<f32>()?.as_slice::<f32>()?)?;
+    let name = &model.node(node.inputs[0].node).name;
+    let weights = patch.add_const(name, tensor0(weights))?;
+    let weights = patch.wire_node(
+        format!("{name}.deq"),
+        DeBlockQuant { bq: Box::new(Q4_0), fact: a.without_value() },
+        &[weights],
+    )?;
+    patch.shunt_outside(model, node.inputs[0], weights[0])?;
+    patch.obliterate(node.inputs[0].node)?;
+    Ok(Some(patch))
+}
+
+#[derive(Debug, Clone, Hash)]
+pub struct DeBlockQuant {
+    pub(crate) bq: Box<dyn BlockQuant>,
+    pub(crate) fact: TypedFact,
+}
+
+impl Op for DeBlockQuant {
+    fn name(&self) -> Cow<str> {
+        "DeBlockQuant".into()
+    }
+
+    fn same_as(&self, other: &dyn Op) -> bool {
+        other
+            .downcast_ref::<Self>()
+            .map(|other| other.bq.same_as(&*self.bq) && other.fact == self.fact)
+            .unwrap_or(false)
+    }
+
+    op_as_typed_op!();
+}
+
+impl EvalOp for DeBlockQuant {
+    fn is_stateless(&self) -> bool {
+        false
+    }
+
+    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
+        unreachable!()
+    }
+
+    fn state(
+        &self,
+        _session: &mut SessionState,
+        _node_id: usize,
+    ) -> TractResult<Option<Box<dyn OpState>>> {
+        Ok(Some(Box::new(self.clone())))
+    }
+}
+
+impl OpState for DeBlockQuant {
+    fn eval(
+        &mut self,
+        _session: &mut SessionState,
+        _op: &dyn Op,
+        inputs: TVec<TValue>,
+    ) -> TractResult<TVec<TValue>> {
+        let input = args_1!(inputs);
+        let blob = input.to_scalar::<Blob>()?;
+        Ok(tvec!(self.bq.dequant_f32(blob)?.into_tvalue()))
+    }
+}
+
+trivial_op_state_freeeze!(DeBlockQuant);
+
+impl TypedOp for DeBlockQuant {
+    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        Ok(tvec!(self.fact.clone()))
+    }
+    as_op!();
+}
diff --git a/core/src/ops/matmul/lir_unary.rs b/core/src/ops/matmul/lir_unary.rs
index 7f6925437d..8ea45d1861 100644
--- a/core/src/ops/matmul/lir_unary.rs
+++ b/core/src/ops/matmul/lir_unary.rs
@@ -5,7 +5,7 @@ use crate::ops::nn::LeakyRelu;
 use ndarray::*;
 use tract_itertools::Itertools;
 
-use tract_linalg::mmm::{BinOp, FusedSpec, MMMInput, MatMatMul, OutputStoreSpec};
+use tract_linalg::mmm::{BinOp, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec};
 use tract_linalg::Scaler;
 use tract_smallvec::ToSmallVec;
 
@@ -23,10 +23,13 @@ pub enum ProtoFusedSpec {
 }
 
 impl ProtoFusedSpec {
-    pub fn name(&self) -> String {
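+    // Human-readable description of one micro-op for LirMatMulUnary debug dumps.
+    // AddMatMul now also shows the (a, b) packing pair selected for the kernel.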
+    pub fn format(&self, mmm: &dyn MatMatMul) -> String {
         use ProtoFusedSpec::*;
         match self {
-            AddMatMul { geo, .. } => format!("matmul(k={})", geo.k),
+            AddMatMul { geo, packing, .. } => {
+                let (a,b) = mmm.packings()[*packing];
+                format!("matmul(k={}, {a:?}•{b:?})", geo.k)
+            },
             BinScalar(_, op) => format!("scalar{op:?}"),
             LeakyRelu(alpha) => format!("leaky_relu({alpha:?})"),
             BinPerRow(_, op, _) => format!("row{op:?}"),
@@ -51,13 +54,13 @@ impl ProtoFusedSpec {
                     geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
                 }
                 let a =
-                    a.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
+                    a.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
                 let mut b = inputs[*b].view();
                 unsafe {
                     geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
                 }
                 let b =
-                    b.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
+                    b.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
                 FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
             }
             ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
@@ -106,9 +109,9 @@ impl ProtoFusedSpec {
                 let a = &inputs[*a];
                 let b = &inputs[*b];
                 let a =
-                    a.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
+                    a.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
                 let b =
-                    b.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
+                    b.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInputValue>>().unwrap();
                 FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
             }
             ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
@@ -273,7 +276,7 @@ impl Op for LirMatMulUnary {
         } else {
             infos.push(format!("Mult: {:?}", self.mmm));
         }
-        infos.push(format!("Ops: {}", self.micro_ops.iter().map(|o| o.name()).join(" >>> ")));
+        infos.push(format!("Ops: {}", self.micro_ops.iter().map(|o| o.format(&*self.mmm)).join(" >>> ")));
         Ok(infos)
     }
diff --git a/core/src/ops/matmul/pack.rs b/core/src/ops/matmul/pack.rs
index 3b080f86cf..2b2dd71f8a 100644
--- a/core/src/ops/matmul/pack.rs
+++ b/core/src/ops/matmul/pack.rs
@@ -1,12 +1,15 @@
 use crate::axes::Axis;
 use crate::internal::*;
 use ndarray::*;
+use tract_linalg::frame::block_quant::RepackingPackedBlockQuantValue;
+use tract_linalg::frame::PackedFormat;
+use tract_linalg::mmm::MMMInputValue;
 
-use tract_linalg::frame::Packer;
+use super::de_block_quant::DeBlockQuant;
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct MatMatMulPack {
-    pub(crate) packer: Packer,
+    pub(crate) packer: PackedFormat,
     pub(crate) k_axis: usize,
     pub(crate) mn_axis: usize,
 }
@@ -56,6 +59,10 @@ impl TypedOp for MatMatMulPack {
         AxesMapping::new(1, 1, axes)
     }
 
+    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
+        self.fuse_dequant_block(model, node)
+    }
+
     as_op!();
 }
 
@@ -99,6 +106,27 @@ impl MatMatMulPack {
         }
     }
 
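+    // Fusion hook (called from fuse() above): when this pack op consumes a constant
+    // DeBlockQuant, bypass dequantization and replace the whole chain by a pre-packed
+    // block-quant constant that repacks its panels to `self.packer` on demand.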
+    fn fuse_dequant_block(
+        &self,
+        model: &TypedModel,
+        node: &TypedNode,
+    ) -> TractResult<Option<TypedModelPatch>> {
+        let Some(prec) = model.single_prec(node.id)? else { return Ok(None) };
+        let Some(deq) = prec.op_as::<DeBlockQuant>() else { return Ok(None) };
+        let Some(weights) = model.outlet_fact(prec.inputs[0])?.konst.as_ref() else {
+            return Ok(None);
+        };
+        let k = deq.fact.shape[self.k_axis].to_usize().unwrap();
+        let value = deq.bq.pack(weights.to_scalar::<Blob>()?, k, self.packer.r)?;
+        let mmm_input: Box<dyn MMMInputValue> =
+            Box::new(RepackingPackedBlockQuantValue { value, pack: self.packer.clone() });
+        let packed = tensor0(Opaque::from(mmm_input)).into_arc_tensor();
+        let mut patch = TypedModelPatch::default();
+        let wire = patch.add_const(&node.name, packed)?;
+        patch.shunt_outside(model, node.id.into(), wire)?;
+        Ok(Some(patch))
+    }
+
     pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
         let mut packed_shape: TVec<D> = input.into();
         packed_shape.remove(self.mn_axis.max(self.k_axis));
diff --git a/core/src/transform.rs b/core/src/transform.rs
index ce6e3ffddd..9befcd7c32 100644
--- a/core/src/transform.rs
+++ b/core/src/transform.rs
@@ -1,6 +1,7 @@
 use crate::internal::*;
 #[cfg(feature = "blas")]
 use crate::ops::einsum::as_blas::AsBlas;
+use crate::ops::matmul::de_block_quant::BlockQuantTransform;
 use num_traits::Float;
 use std::borrow::Cow;
 use std::fmt::Debug;
@@ -21,6 +22,7 @@ pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
             build_float_translator::<f16, f32>(name.strip_prefix("f16-to-f32"))
         }
         "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
+        "block-quant" => Some(Box::new(BlockQuantTransform)),
         _ => None,
     }
 }
diff --git a/data/src/blob.rs b/data/src/blob.rs
index 758fac0f02..e4142ce298 100644
--- a/data/src/blob.rs
+++ b/data/src/blob.rs
@@ -73,7 +73,7 @@ impl Blob {
         let mut data = null_mut();
         if layout.size() > 0 {
             data = unsafe { alloc(layout) };
-            assert!(!data.is_null());
+            assert!(!data.is_null(), "failed to allocate {layout:?}");
         }
         Blob { layout, data }
     }
@@ -139,10 +139,14 @@ impl std::fmt::Display for Blob {
         assert!(self.data.is_null() == self.layout.size().is_zero());
         write!(
             fmt,
-            "Blob of {} bytes (align @{}): {}",
+            "Blob of {} bytes (align @{}): {} {}",
             self.len(),
             self.layout.align(),
-            String::from_utf8_lossy(self)
+            String::from_utf8(
+                self.iter().take(20).copied().flat_map(std::ascii::escape_default).collect::<Vec<u8>>()
+            )
+            .unwrap(),
+            if self.len() >= 20 { "[...]" } else { "" }
         )
     }
 }
diff --git a/data/src/datum.rs b/data/src/datum.rs
index 5f8a7b9aef..702d2e5231 100644
--- a/data/src/datum.rs
+++ b/data/src/datum.rs
@@ -343,10 +343,10 @@ impl DatumType {
 
     #[inline]
     pub fn alignment(&self) -> usize {
-        match self {
-            DatumType::TDim => std::mem::size_of::<usize>(),
-            DatumType::String => std::mem::size_of::<usize>(),
-            _ => self.size_of(),
+        if self.is_copy() {
+            self.size_of()
+        } else {
+            std::mem::size_of::<usize>()
         }
     }
diff --git a/data/src/tensor.rs b/data/src/tensor.rs
index e08c88fea9..e8ee38444c 100644
--- a/data/src/tensor.rs
+++ b/data/src/tensor.rs
@@ -1526,6 +1526,11 @@ impl Tensor {
         compute_natural_stride_to(&mut strides, shape);
         strides
     }
+
+    pub fn into_blob(mut self) -> TractResult<Blob> {
+        ensure!(self.dt.is_copy());
+        Ok(std::mem::take(&mut self.data))
+    }
 }
 
 impl PartialEq for Tensor {
diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml
index 58dcb59539..b6200c2c16 100644
--- a/linalg/Cargo.toml
+++ b/linalg/Cargo.toml
@@ -14,6 +14,7 @@ edition = "2021"
 maintenance = { status = "actively-developed" }
 
 [dependencies]
+byteorder.workspace = true
 derive-new.workspace = true
 downcast-rs.workspace = true
 dyn-clone.workspace = true
diff --git a/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_128x1_core.tmpl b/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_128x1_core.tmpl
index e0c68f6613..15fb315fbb 100644
--- a/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_128x1_core.tmpl
+++ b/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_128x1_core.tmpl
@@ -1,16 +1,6 @@
 // vim: ft=arm
 
 // C tile regs: v16 to v31, no need to preserve
-//
-//      v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0]
-//      v16[1] v18[1]
-//      v16[2] v18[2]
-//      v16[3] v18[3]
-//
-//      v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0]
-//      v17[1] v19[1]
-//      v17[2] v19[2]
-//      v17[3] v19[3]
 
 // no preservation either for v0-v7...
 // v8..v15 are callee-preserved
diff --git a/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_64x1.core.tmpl b/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_64x1.core.tmpl
new file mode 100644
index 0000000000..b77e638460
--- /dev/null
+++ b/linalg/arm64/arm64fp16/arm64fp16_mmm_f16_64x1.core.tmpl
@@ -0,0 +1,194 @@
+// vim: ft=arm
+
+// C tile regs: v16 to v31, no need to preserve
+
+// no preservation either for v0-v7...
+// v8..v15 are callee-preserved
+// packed A buffering (2x8 values): alternating v0, v1 with v2, v3
+// packed B buffering (2x8 values): alternating v4, v5 with v6, v7
+
+.text
+.align 4
+
+{% if needs_pragma == true %}
+.cpu generic+fp+simd+fp16
+{% endif %}
+.global {{G}}arm64fp16_mmm_f16_64x1_{{core}}_{{suffix}}
+{{G}}arm64fp16_mmm_f16_64x1_{{core}}_{{suffix}}:
+
+    stp x20, x21, [sp, #-16]!
+    stp x22, x23, [sp, #-16]!
+    stp x24, x25, [sp, #-16]!
+
+    stp d8, d9, [sp, #-16]!
+    stp d10, d11, [sp, #-16]!
+    stp d12, d13, [sp, #-16]!
+    stp d14, d15, [sp, #-16]!
+
+{% include "dispatcher.tmpliq" %}
+
+.add_mat_mul:
+    ldp x2, x4, [x0, #24]       // b, packing
+    ldp x3, x1, [x0, #8]        // k, a
+
+    cmp x3, #0
+    beq .non_linear_loop
+
+    cmp x4, #1
+    beq .q4f16
+
+.p2align 4
+.packed_packed_loop_1:
+    ld1 { v8.h }[0], [ x2 ], #2
+    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [ x1 ], #64
+    ld1 { v4.8h, v5.8h, v6.8h, v7.8h }, [ x1 ], #64
+
+    fmla v24.8h, v0.8h, v8.h[0]
+    fmla v25.8h, v1.8h, v8.h[0]
+    fmla v26.8h, v2.8h, v8.h[0]
+    fmla v27.8h, v3.8h, v8.h[0]
+    fmla v28.8h, v4.8h, v8.h[0]
+    fmla v29.8h, v5.8h, v8.h[0]
+    fmla v30.8h, v6.8h, v8.h[0]
+    fmla v31.8h, v7.8h, v8.h[0]
+    subs x3, x3, #1
+    bge .packed_packed_loop_1
+
+    b .non_linear_loop
+
+.q4f16:
+    movi v14.16b, 8
+    movi v15.16b, 15
+.q4f16_outerloop:
+    // scales
+    ld1 { v16.8h-v19.8h }, [ x1 ], #64
+    ld1 { v20.8h-v23.8h }, [ x1 ], #64
+    mov x4, #32
+
+.p2align 4
+.q4f16_innerloop:
+    ld1 { v9.16b-v10.16b }, [x1], #32
+    ld1 { v8.h }[0], [ x2 ], #2
+
+    and  v11.16b, v9.16b, v15.16b
+    ushr v12.16b, v9.16b, 4
+    zip1 v0.16b, v11.16b, v12.16b
+    zip2 v2.16b, v11.16b, v12.16b
+
+    and  v11.16b, v10.16b, v15.16b
+    ushr v12.16b, v10.16b, 4
+    zip1 v4.16b, v11.16b, v12.16b
+    zip2 v6.16b, v11.16b, v12.16b
+
+    sub v0.16b, v0.16b, v14.16b
+    sub v2.16b, v2.16b, v14.16b
+    sub v4.16b, v4.16b, v14.16b
+    sub v6.16b, v6.16b, v14.16b
+
+{% for i in (0..3) %}
+    sshll2 v{{i|times:2|plus:1}}.8h, v{{i|times:2}}.16b, 0
+    sshll  v{{i|times:2}}.8h, v{{i|times: 2}}.8b, 0
+{% endfor %}
+
+{% for i in (0..7) %}
+    scvtf v{{i}}.8h, v{{i}}.8h
+{% endfor %}
+
+{% for i in (0..7) %}
+    fmul v{{i}}.8h, v{{i}}.8h, v{{i|plus:16}}.8h
+{% endfor %}
+
+{% for i in (0..7) %}
+    fmla v{{ i|plus: 24 }}.8h, v{{i}}.8h, v8.h[0]
+{% endfor %}
+
+    subs x4, x4, #1
+    bne .q4f16_innerloop
+
+    subs x3, x3, #32
+    bne .q4f16_outerloop
+
+    b .non_linear_loop
+
+{% include "arm64fp16_mmm_f16_scalars.tmpliq" from:24, to:31%}
+{% include "arm64fp16_mmm_f16_per_rows.tmpliq" mr:64, from:24, to:31%}
+{% include "arm64fp16_mmm_f16_per_cols.tmpliq" mr:64, from:24, to:31%}
+
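+// Non-linear entry points. add_unicast accumulates the C column held in v24-v31,
+// loading values one at a time with the row stride in x6; when the stride is 2 bytes
+// (contiguous f16 rows) it branches to the vectorized do_per_row_add path instead.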
+.add_unicast:
+    ldp x5, x6, [x0, #8]        // c base ptr, rsc
+    cmp x6, #2
+    beq .do_per_row_add
+
+    {% for reg in (24..31) %}
+        {% for lane in (0..7) %}
+            ld1 {v0.h}[{{lane}}], [ x5 ], x6
+        {% endfor %}
+        fadd v{{reg}}.8h, v{{reg}}.8h, v0.8h
+    {% endfor %}
+
+    b .non_linear_loop
+
+.do_per_row_add:
+    ld1 {v0.8h-v3.8h}, [x5], #64
+    ld1 {v4.8h-v7.8h}, [x5], #64
+
+    {% for r in (0..7) %}
+        fadd v{{r| plus: 24}}.8h, v{{r | plus: 24}}.8h, v{{r}}.8h
+    {% endfor %}
+
+    b .non_linear_loop
+
+.add_row_col_products:
+    ldr x3, [x0, #16]
+    ldr x2, [x0, #8]
+
+    ld1 {v8.h}[0], [ x3 ]
+
+    {% for r in (0..7) %}
+        ldr q{{r}}, [x2], #16
+    {% endfor %}
+
+    fmla v24.8h, v0.8h, v8.h[0]
+    fmla v25.8h, v1.8h, v8.h[0]
+    fmla v26.8h, v2.8h, v8.h[0]
+    fmla v27.8h, v3.8h, v8.h[0]
+    fmla v28.8h, v4.8h, v8.h[0]
+    fmla v29.8h, v5.8h, v8.h[0]
+    fmla v30.8h, v6.8h, v8.h[0]
+    fmla v31.8h, v7.8h, v8.h[0]
+
+    b .non_linear_loop
+
+.store:
+    ldp x5, x6, [x0, #8]        // c base ptr, rsc
+
+    cmp x6, #2
+    beq .store_strides_contig
+
+    {% for reg in (24..31) %}
+        {% for lane in (0..7) %}
+            st1 { v{{reg}}.h }[{{lane}}], [ x5 ], x6
+        {% endfor %}
+    {% endfor %}
+    b .non_linear_loop
+
+.store_strides_contig:
+
+    {% for reg in (24..31) %}
+        st1 { v{{reg}}.8h }, [ x5 ], #16
+    {% endfor %}
+    b .non_linear_loop
+
+.return:
+
+    ldp d14, d15, [sp], #16
+    ldp d12, d13, [sp], #16
+    ldp d10, d11, [sp], #16
+    ldp d8, d9, [sp], #16
+
+    ldp x24, x25, [sp], #16
+    ldp x22, x23, [sp], #16
+    ldp x20, x21, [sp], #16
+
+    ret
+
diff --git a/linalg/benches/utils.rs b/linalg/benches/utils.rs
index b45310deeb..ad1163f9c4 100644
--- a/linalg/benches/utils.rs
+++ b/linalg/benches/utils.rs
@@ -1,7 +1,7 @@
 #![allow(dead_code)]
 use criterion::*;
 use tract_data::internal::*;
-use tract_linalg::frame::mmm::{FusedSpec, MMMInput};
+use tract_linalg::frame::mmm::{FusedSpec, MMMInputValue};
 use tract_linalg::frame::MatMatMul;
 use DatumType::*;
 
@@ -38,8 +38,8 @@ unsafe fn run(
     n: usize,
     be: &mut Bencher,
     mmm: &dyn MatMatMul,
-    a: &dyn MMMInput,
-    b: &dyn MMMInput,
+    a: &dyn MMMInputValue,
+    b: &dyn MMMInputValue,
     cold: bool,
 ) {
     let mut scratch = mmm.allocate_scratch_space();
diff --git a/linalg/src/arm32/armv7neon.rs b/linalg/src/arm32/armv7neon.rs
index dcb2c6a9dc..6bbaf0bcee 100644
--- a/linalg/src/arm32/armv7neon.rs
+++ b/linalg/src/arm32/armv7neon.rs
@@ -17,22 +17,22 @@ MMMExternKernel!(f32, armv7neon_mmm_f32_32x1_generic; 32, 1; 4, 4; 0, 0; prefetch, crate::arm32::has_neon(),
 
 MMMExternKernel!(i32, armv7neon_mmm_i32_8x4; 8, 4; 32, 4; 0, 0; prefetch, crate::arm32::has_neon(),
     packing_defs: {
-        const I8_A: Packer = Packer::new(DatumType::I8, 8, 32, 0);
-        const I8_B: Packer = Packer::new(DatumType::I8, 4, 4, 0);
+        const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 8, 32, 0);
+        const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 4, 4, 0);
         const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B);
     },
     packings: I8_I8,
-    test: mmm_packed_packed_tests!{ true, armv7neon_mmm_i32_8x4, i8i8:1, i8, i8, i32, i32 }
+    test: mmm_packed_packed_tests!{ true, armv7neon_mmm_i32_8x4, i8i8:1 }
 );
 
 MMMExternKernel!(i32, armv7neon_mmm_i32_32x1; 32,1 ; 32, 4; 0, 0; prefetch, crate::arm32::has_neon(),
     packing_defs: {
-        const I8_A: Packer = Packer::new(DatumType::I8, 32, 32, 0);
-        const I8_B: Packer = Packer::new(DatumType::I8, 1, 4, 0);
+        const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 32, 32, 0);
+        const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 1, 4, 0);
         const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B);
     },
     packings: I8_I8,
-    test: mmm_packed_packed_tests!{ true, armv7neon_mmm_i32_32x1, i8i8:1, i8, i8, i32, i32 }
+    test: mmm_packed_packed_tests!{ true, armv7neon_mmm_i32_32x1, i8i8:1 }
 );
 
 sigmoid_impl!(f32, armv7neon_sigmoid_f32_4n, 4, 4, crate::arm32::has_neon());
diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs
index 9a3d4d8628..96de1a94ec 100644
--- a/linalg/src/arm64/arm64fp16.rs
+++ b/linalg/src/arm64/arm64fp16.rs
@@ -18,5 +18,16 @@ MMMExternKernel!(f16, arm64fp16_mmm_f16_32x4_a55; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16());
 MMMExternKernel!(f16, arm64fp16_mmm_f16_128x1_gen; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16());
 MMMExternKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16());
 
+MMMExternKernel!(f16, arm64fp16_mmm_f16_64x1_gen; 64, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16(),
+    packing_defs: {
+        use crate::frame::block_quant::*;
+        const PQ40_R64: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 64);
+        const F16_B: PackedFormat = PackedFormat::new(DatumType::F16, 1, 2, 0);
+        const PQ40_F16: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&PQ40_R64, &F16_B);
+    },
+    packings: PQ40_F16,
+    test: mmm_packed_packed_tests!{ crate::arm64::has_fp16(), arm64fp16_mmm_f16_64x1_gen, q40f16:1 }
+);
+
 tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16());
 sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16());
diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs
index 77b6744192..465faa747c 100644
--- a/linalg/src/arm64/arm64simd.rs
+++ b/linalg/src/arm64/arm64simd.rs
@@ -32,22 +32,22 @@ MMMExternKernel!(f32, arm64simd_mmm_f32_64x1_gen; 64, 1; 16, 16; 1, 1; no_prefetch, true,
 
 MMMExternKernel!(i32, arm64simd_mmm_i32_8x8; 8, 8; 16, 16; 0,0; no_prefetch, true,
     packing_defs: {
-        const I8_A: Packer = Packer::new(DatumType::I8, 8, 16, 0);
-        const I8_B: Packer = Packer::new(DatumType::I8, 8, 16, 0);
+        const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 8, 16, 0);
+        const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 8, 16, 0);
         const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B);
     },
     packings: I8_I8,
-    test: mmm_packed_packed_tests!{ true, arm64simd_mmm_i32_8x8, i8i8:1, i8, i8, i32, i32 }
+    test: mmm_packed_packed_tests!{ true, arm64simd_mmm_i32_8x8, i8i8:1 }
 );
 
 MMMExternKernel!(i32, arm64simd_mmm_i32_64x1; 64, 1; 16, 1; 0,0; no_prefetch, true,
     packing_defs: {
-        const I8_A: Packer = Packer::new(DatumType::I8, 64, 16, 0);
-        const I8_B: Packer = Packer::new(DatumType::I8, 1, 1, 0);
+        const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 64, 16, 0);
+        const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 1, 1, 0);
         const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B);
     },
     packings: I8_I8,
-    test: mmm_packed_packed_tests!{ true, arm64simd_mmm_i32_64x1, i8i8:1, i8, i8, i32, i32 }
+    test: mmm_packed_packed_tests!{ true, arm64simd_mmm_i32_64x1, i8i8:1 }
 );
 
 tanh_impl!(f32, arm64simd_tanh_f32_4n, 4, 4, true);
diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs
index 7f5c3b70f8..bb7b550793 100644
--- a/linalg/src/frame.rs
+++ b/linalg/src/frame.rs
@@ -1,6 +1,7 @@
 #[macro_use]
+pub mod block_quant;
+#[macro_use]
 pub mod element_wise;
-
 #[macro_use]
 pub mod by_scalar;
 #[macro_use]
@@ -17,7 +18,7 @@ pub mod sigmoid;
 pub mod tanh;
 pub mod element_wise_helper;
 
-pub use mmm::pack::Packer;
+pub use mmm::pack::PackedFormat;
 pub use mmm::pack::PackingWriter;
 
 pub use self::element_wise::{ElementWise, ElementWiseImpl};
diff --git a/linalg/src/frame/block_quant/helpers.rs b/linalg/src/frame/block_quant/helpers.rs
new file mode 100644
index 0000000000..1158880a6e
--- /dev/null
+++ b/linalg/src/frame/block_quant/helpers.rs
@@ -0,0 +1,65 @@
+use byteorder::{ReadBytesExt, WriteBytesExt, LE};
+use std::io::{Cursor, Read, Write};
+use tract_data::internal::*;
+
+pub struct NibbleReader<R> {
+    second_half: Option<i8>,
+    reader: R,
+}
+
+impl<'s> NibbleReader<Cursor<&'s [u8]>> {
+    pub fn for_slice(slice: &'s [u8]) -> Self {
+        NibbleReader::new(Cursor::new(slice))
+    }
+}
+
+impl<R: Read> NibbleReader<R> {
+    pub fn new(reader: R) -> NibbleReader<R> {
+        NibbleReader { reader, second_half: None }
+    }
+
+    pub fn read_f16(&mut self) -> f16 {
+        assert!(self.second_half.is_none());
+        f16::from_bits(self.reader.read_u16::<LE>().unwrap())
+    }
+
+    pub fn read_i4(&mut self) -> i8 {
+        if let Some(second) = self.second_half.take() {
+            second
+        } else {
+            let byte = self.reader.read_u8().unwrap();
+            self.second_half = Some((byte >> 4) as i8);
+            (byte & 0x0F) as i8
+        }
+    }
+}
+
+pub struct NibbleWriter<W> {
+    first_half: Option<i8>,
+    writer: W,
+}
+
+impl<'s> NibbleWriter<Cursor<&'s mut [u8]>> {
+    pub fn for_slice(slice: &'s mut [u8]) -> Self {
+        NibbleWriter::new(Cursor::new(slice))
+    }
+}
+
+impl<W: Write> NibbleWriter<W> {
+    pub fn new(writer: W) -> NibbleWriter<W> {
+        NibbleWriter { writer, first_half: None }
+    }
+
+    pub fn write_f16(&mut self, f: f16) {
+        assert!(self.first_half.is_none());
+        self.writer.write_u16::<LE>(f.to_bits()).unwrap()
+    }
+
+    pub fn write_i4(&mut self, q: i8) {
+        if let Some(first) = self.first_half.take() {
+            self.writer.write_u8(first as u8 | ((q as u8) << 4)).unwrap()
+        } else {
+            self.first_half = Some(q);
+        }
+    }
+}
diff --git a/linalg/src/frame/block_quant/mod.rs b/linalg/src/frame/block_quant/mod.rs
new file mode 100644
index 0000000000..09bcf9285c
--- /dev/null
+++ b/linalg/src/frame/block_quant/mod.rs
@@ -0,0 +1,211 @@
+use downcast_rs::{impl_downcast, Downcast};
+use dyn_clone::DynClone;
+use dyn_hash::DynHash;
+use tract_data::internal::*;
+use tract_data::itertools::Itertools;
+
+use std::alloc::Layout;
+use std::borrow::Cow;
+use std::fmt::{Debug, Display};
+use std::hash::Hash;
+use std::ops::Deref;
+
+mod helpers;
+mod q4_0;
+mod repack;
+
+pub use helpers::{NibbleReader, NibbleWriter};
+pub use q4_0::Q4_0;
+pub use repack::RepackingPackedBlockQuantValue;
+
+use crate::mmm::{EagerPackedInput, MMMInputFormat};
+
+use super::PackedFormat;
+
+pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downcast {
+    fn same_as(&self, other: &dyn BlockQuant) -> bool;
+
+    fn block_len(&self) -> usize;
+
+    fn block_bytes(&self) -> usize;
+
+    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
+    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
+    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
+    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);
+
+    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
+        unsafe {
+            let blocks = input.len() / self.block_len();
+            let mut quant = Blob::for_layout(
+                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
+            );
+            for b in 0..blocks {
+                let block = &input[b * self.block_len()..][..self.block_len()];
+                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
+                self.quant_block_f16(block, qblock);
+            }
+            Ok(quant)
+        }
+    }
+
+    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
+        unsafe {
+            let blocks = input.len() / self.block_len();
+            let mut quant = Blob::for_layout(
+                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
+            );
+            for b in 0..blocks {
+                let block = &input[b * self.block_len()..][..self.block_len()];
+                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
+                self.quant_block_f32(block, qblock);
+            }
+            Ok(quant)
+        }
+    }
+
+    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
+        unsafe {
+            let blocks = input.len() / self.block_bytes();
+            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
+            let slice = tensor.as_slice_mut::<f32>()?;
+            for b in 0..blocks {
+                let block = &mut slice[b * self.block_len()..][..self.block_len()];
+                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
+                self.dequant_block_f32(qblock, block);
+            }
+            Ok(tensor)
+        }
+    }
+
+    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
+        unsafe {
+            let blocks = input.len() / self.block_bytes();
+            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
+            let slice = tensor.as_slice_mut::<f16>()?;
+            for b in 0..blocks {
+                let block = &mut slice[b * self.block_len()..][..self.block_len()];
+                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
+                self.dequant_block_f16(qblock, block);
+            }
+            Ok(tensor)
+        }
+    }
+
+    fn pack(&self, input: &[u8], k: usize, r: usize) -> TractResult<EagerPackedInput>;
+
+    unsafe fn repack_panel(
+        &self,
+        value: &EagerPackedInput,
+        target: &PackedFormat,
+        panel: usize,
+        scratch: *mut u8,
+    ) -> TractResult<()>;
+}
+
+dyn_clone::clone_trait_object!(BlockQuant);
+dyn_hash::hash_trait_object!(BlockQuant);
+impl_downcast!(BlockQuant);
+
+#[derive(Clone, Hash)]
+pub enum StaticBlockQuant {
+    Owned(Box<dyn BlockQuant>),
+    Borrow(&'static dyn BlockQuant),
+}
+
+impl Deref for StaticBlockQuant {
+    type Target = dyn BlockQuant;
+    fn deref(&self) -> &dyn BlockQuant {
+        match self {
+            StaticBlockQuant::Owned(o) => &**o,
+            StaticBlockQuant::Borrow(o) => *o,
+        }
+    }
+}
+
+#[derive(Clone, Hash)]
+pub struct PackedBlockQuantFormat {
+    pub bq: StaticBlockQuant,
+    pub r: usize,
+}
+
+impl Display for PackedBlockQuantFormat {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Packed{} (r={})", &*self.bq, self.r)
+    }
+}
+
+impl Debug for PackedBlockQuantFormat {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        <Self as Display>::fmt(self, f)
+    }
+}
+
+impl PackedBlockQuantFormat {
+    pub const fn new(bq: &'static dyn BlockQuant, r: usize) -> Self {
+        PackedBlockQuantFormat { bq: StaticBlockQuant::Borrow(bq), r }
+    }
+
+    #[cfg(test)]
+    pub fn simulate_precision_loss(
+        &self,
+        mut tensor: Tensor,
+        block_axis: usize,
+    ) -> TractResult<Tensor> {
+        ensure!(block_axis == tensor.rank() - 1);
+        ensure!(tensor.shape()[block_axis] % self.bq.block_len() == 0);
+        let mut scratch = vec![0u8; self.bq.block_bytes()];
+        if tensor.datum_type() == f32::datum_type() {
+            for block in tensor.as_slice_mut::<f32>()?.chunks_mut(self.bq.block_len()) {
+                self.bq.quant_block_f32(block, &mut scratch);
+                self.bq.dequant_block_f32(&scratch, block);
+            }
+            Ok(tensor)
+        } else if tensor.datum_type() == f16::datum_type() {
+            for block in tensor.as_slice_mut::<f16>()?.chunks_mut(self.bq.block_len()) {
+                self.bq.quant_block_f16(block, &mut scratch);
+                self.bq.dequant_block_f16(&scratch, block);
+            }
+            Ok(tensor)
+        } else {
+            todo!()
+        }
+    }
+}
+
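+// PackedBlockQuantFormat can prepare a plain f32/f16 tensor by quantizing it block by
+// block and packing the result for a kernel with r rows. can_prepare_types() stays
+// empty on purpose: the generic datum-type driven packing selection must never pick
+// this format; it is only chosen explicitly (see einsum codegen).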
+impl MMMInputFormat for PackedBlockQuantFormat {
+    fn can_prepare_types(&self) -> Vec<DatumType> {
+        vec![]
+    }
+
+    fn prepare_tensor(
+        &self,
+        t: &Tensor,
+        k_axis: usize,
+        mn_axis: usize,
+    ) -> TractResult<Box<dyn crate::mmm::MMMInputValue>> {
+        let k = t.shape()[k_axis];
+        assert!(k % self.bq.block_len() == 0);
+        let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
+            Cow::Borrowed(t)
+        } else {
+            Cow::Owned(t.clone().move_axis(1, 0)?)
+        };
+        let quant = if t.datum_type() == f32::datum_type() {
+            self.bq.quant_f32(t.as_slice()?)?
+        } else if t.datum_type() == f16::datum_type() {
+            self.bq.quant_f16(t.as_slice()?)?
+        } else {
+            todo!()
+        };
+        Ok(Box::new(self.bq.pack(&quant, k, self.r)?))
+    }
+
+    fn k_alignment(&self) -> usize {
+        self.bq.block_len()
+    }
+
+    fn r(&self) -> usize {
+        self.r
+    }
+}
diff --git a/linalg/src/frame/block_quant/q4_0.rs b/linalg/src/frame/block_quant/q4_0.rs
new file mode 100644
index 0000000000..6210f7cac6
--- /dev/null
+++ b/linalg/src/frame/block_quant/q4_0.rs
@@ -0,0 +1,292 @@
+use super::*;
+use num_traits::{AsPrimitive, Float, Zero};
+use std::alloc::Layout;
+
+#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub struct BaseQ4_0<const QK: usize = 32>;
+
+pub const Q4_0: BaseQ4_0 = BaseQ4_0::<32>;
+
+impl<const QK: usize> BaseQ4_0<QK> {
+    fn quant_block<T>(&self, block: &[T], quant: &mut [u8])
+    where
+        f32: AsPrimitive<T>,
+        T: Debug + Float + AsPrimitive<f16> + AsPrimitive<i8> + 'static,
+    {
+        assert!(quant.len() == self.block_bytes());
+        assert!(block.len() == self.block_len());
+        let mut writer = NibbleWriter::for_slice(quant);
+        let mut amax = T::zero();
+        let mut max = T::zero();
+        for v in block {
+            if amax < v.abs() {
+                amax = v.abs();
+                max = *v;
+            }
+        }
+        let scale: T = max / (-8f32).as_();
+        let r_scale = if scale.is_zero() { T::zero() } else { scale.recip() };
+        writer.write_f16(scale.as_());
+
+        for x in block {
+            let i: i8 = (*x * r_scale + (8.5f32).as_()).as_();
+            writer.write_i4(i.min(15));
+        }
+    }
+
+    fn dequant_block<T>(&self, quant: &[u8], block: &mut [T])
+    where
+        T: Float + 'static,
+        f16: AsPrimitive<T>,
+        i8: AsPrimitive<T>,
+    {
+        assert!(quant.len() == self.block_bytes());
+        assert!(block.len() == self.block_len());
+        let mut nibbles = NibbleReader::for_slice(quant);
+        let d: T = nibbles.read_f16().as_();
+        for x in block {
+            *x = (nibbles.read_i4() - 8).as_() * d;
+        }
+    }
+
+    unsafe fn repack_panel_t<T>(
+        &self,
+        value: &EagerPackedInput,
+        target: &PackedFormat,
+        panel: usize,
+        scratch: *mut u8,
+    ) -> TractResult<()>
+    where
+        T: Float + 'static,
+        f16: AsPrimitive<T>,
+        i8: AsPrimitive<T>,
+    {
+        ensure!(value.format.r() == target.r);
+        ensure!(value.k % self.block_len() == 0);
+        let scratch = std::slice::from_raw_parts_mut(scratch as *mut T, value.k * target.r);
+        let blocks_for_k = value.k / self.block_len();
+        let row_bytes = blocks_for_k * self.block_bytes();
+        let mut input = NibbleReader::for_slice(&value.packed[panel * target.r * row_bytes..]);
+        let mut scales = vec![T::zero(); target.r];
+        let mut scratch = scratch.iter_mut();
+        for _ in 0..blocks_for_k {
+            for s in &mut scales {
+                *s = input.read_f16().as_();
+            }
+            for _ in 0..self.block_len() {
+                for &s in &scales {
+                    *scratch.next().unwrap() = s * (input.read_i4() - 8).as_();
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
+    fn same_as(&self, other: &dyn BlockQuant) -> bool {
+        other.downcast_ref::<Self>().map(|other| other == self).unwrap_or(false)
+    }
+
+    fn block_len(&self) -> usize {
+        QK
+    }
+
+    fn block_bytes(&self) -> usize {
+        2 + self.block_len() / 2
+    }
+
+    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]) {
+        self.quant_block(block, quant)
+    }
+
+    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]) {
+        self.quant_block(block, quant)
+    }
+
+    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]) {
+        self.dequant_block(quant, block)
+    }
+
+    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]) {
+        self.dequant_block(quant, block)
+    }
+
+    //  s0_0 n0_0 n0_1 n0_2 n0_3 ... n0_30 n0_31 s0_32 n0_32 n0_33 ...
+    //  s1_0 n1_0 n1_1 n1_2 n1_3 ... n1_30 n1_31 s1_32 n1_32 n1_33 ...
+    //
+    //  becomes (with r=4)
+    //
+    //  s0_0 s1_0 s2_0 s3_0 n0_0 n1_0 n2_0 n3_0 n0_1 n1_1 n2_1 n3_1 ... n0_33 n1_33 n2_33 n3_33
+    //  s0_32 s1_32 s2_32 s3_32 n0_0 n1_0 n2_0 n3_0 n0_1 n1_1 n2_1 n3_1 ... n0_33 n1_33 n2_33 n3_33
+    //  ...
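+    //
+    //  For the default QK=32, a block is one f16 scale plus 32 nibbles, so
+    //  block_bytes() = 2 + 32/2 = 18 bytes, and a packed panel costs r * 18 bytes per
+    //  32 values of k. A trailing partial panel is padded to r rows with zero scales
+    //  and zero nibbles.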
+    fn pack(&self, input: &[u8], k: usize, r: usize) -> TractResult<EagerPackedInput> {
+        assert!(input.len() % self.block_bytes() == 0);
+        assert!(k % self.block_len() == 0);
+        let m = if input.len() == 0 {
+            0
+        } else {
+            input.len() / self.block_bytes() * self.block_len() / k
+        };
+        let full_panels = m / r;
+        let panels = m.divceil(r);
+        let blocks_for_k = k / self.block_len();
+        let row_bytes = blocks_for_k * self.block_bytes();
+        let panel_bytes = row_bytes * r;
+        let mut blob =
+            unsafe { Blob::for_layout(Layout::from_size_align(panel_bytes * panels, 128)?) };
+        let mut writer = NibbleWriter::for_slice(&mut blob);
+        for p in 0..full_panels {
+            let input = &input[r * p * row_bytes..];
+            let mut readers =
+                (0..r).map(|r| NibbleReader::for_slice(&input[r * row_bytes..])).collect_vec();
+            for _ in 0..blocks_for_k {
+                for reader in &mut readers {
+                    let scale = reader.read_f16();
+                    writer.write_f16(scale);
+                }
+                for _ in 0..self.block_len() {
+                    for r in &mut readers {
+                        let nib = r.read_i4();
+                        writer.write_i4(nib);
+                    }
+                }
+            }
+        }
+        let partial_panel = m % r;
+        if partial_panel > 0 {
+            let input = &input[r * full_panels * row_bytes..];
+            let mut readers = (0..partial_panel)
+                .map(|r| NibbleReader::for_slice(&input[r * row_bytes..]))
+                .collect_vec();
+            for _ in 0..blocks_for_k {
+                readers.iter_mut().for_each(|r| writer.write_f16(r.read_f16()));
+                (partial_panel..r).for_each(|_| writer.write_f16(f16::zero()));
+                for _ in 0..self.block_len() {
+                    for r in &mut readers {
+                        writer.write_i4(r.read_i4());
+                    }
+                    (partial_panel..r).for_each(|_| writer.write_i4(0));
+                }
+            }
+        }
+        Ok(EagerPackedInput {
+            format: Box::new(PackedBlockQuantFormat {
+                bq: StaticBlockQuant::Owned(Box::new(*self)),
+                r,
+            }),
+            packed: blob,
+            mn: m,
+            k,
+            panel_bytes,
+        })
+    }
+
+    unsafe fn repack_panel(
+        &self,
+        value: &EagerPackedInput,
+        target: &PackedFormat,
+        panel: usize,
+        scratch: *mut u8,
+    ) -> TractResult<()> {
+        dispatch_floatlike!(Self::repack_panel_t(target.dt)(self, value, target, panel, scratch))
+    }
+}
+
+impl<const QK: usize> Display for BaseQ4_0<QK> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Q4_0")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use num_traits::Zero;
+    use tract_data::internal::tract_ndarray::Array2;
+
+    use crate::frame::PackedFormat;
+
+    use super::*;
+
+    fn cycle_f32(b: impl BlockQuant, data: &[f32]) {
+        let mut input = data.to_vec();
+        while input.len() % b.block_len() != 0 {
+            input.push(0f32);
+        }
+        let quant = b.quant_f32(&input).unwrap();
+        let result = b.dequant_f32(&quant).unwrap();
+        let view = &result.as_slice::<f32>().unwrap()[..data.len()];
+        assert_eq!(data, view);
+    }
+
+    fn cycle_f16(b: impl BlockQuant, data: &[f32]) {
+        let mut input = data.iter().map(|f| f16::from_f32(*f)).collect_vec();
+        while input.len() % b.block_len() != 0 {
+            input.push(f16::zero());
+        }
+        let quant = b.quant_f16(&input).unwrap();
+        let result = b.dequant_f16(&quant).unwrap();
+        let view = &result.as_slice::<f16>().unwrap();
+        assert_eq!(&input, view);
+    }
+
+    #[test]
+    fn loop_q4f32_pos() {
+        cycle_f32(Q4_0, &[1.0, 2.0, 3.0, 4.0]);
+    }
+
+    #[test]
+    fn loop_q4f16_pos() {
+        cycle_f16(Q4_0, &[1.0, 2.0, 3.0, 4.0]);
+    }
+
+    #[test]
+    fn loop_q4f32_neg() {
+        cycle_f32(Q4_0, &[-1.0, -2.0, -3.0, -4.0]);
+    }
+
+    #[test]
+    fn loop_q4f16_neg() {
+        cycle_f16(Q4_0, &[-1.0, -2.0, -3.0, -4.0]);
+    }
+
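+    // The big-magnitude cases exercise the extreme of the nibble range: with
+    // scale = max / -8, the max value quantizes to nibble 0 and decodes back
+    // exactly as (0 - 8) * scale.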
+    #[test]
+    fn loop_q4_big_pos() {
+        cycle_f32(Q4_0, &[1234.0]);
+        cycle_f16(Q4_0, &[-1.0, -2.0, -3.0, -4.0]);
+    }
+
+    #[test]
+    fn loop_q4_big_neg() {
+        cycle_f32(Q4_0, &[-1234.0]);
+        cycle_f16(Q4_0, &[-1234.0]);
+    }
+
+    #[test]
+    fn packing() -> TractResult<()> {
+        let (q, k, m, r) = (BaseQ4_0::<2>, 4, 4, 2);
+        let weights_orig =
+            Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
+                .into_tensor();
+        let weights_f32 =
+            q.dequant_f32(&q.quant_f32(weights_orig.as_slice::<f32>()?)?)?.into_shape(&[m, k])?;
+        eprintln!("{:?}", weights_f32.to_array_view::<f32>()?);
+        let packer = PackedFormat::new(f32::datum_type(), r, 128, 0);
+        let packed_f32 = packer.pack_tensor(&weights_f32, 1, 0)?;
+        assert_eq!(packed_f32.panels_count(), 2);
+
+        let q4 = q.quant_f32(&weights_f32.as_slice::<f32>()?)?;
+        let packed_q4 = q.pack(&q4, k, r)?;
+
+        for panel in 0..2 {
+            unsafe {
+                let panel_f32 = packed_f32.panel_bytes(panel, None)?;
+                let panel_f32 = std::slice::from_raw_parts(panel_f32 as *const f32, k * r);
+                eprintln!("{panel_f32:?}");
+                let mut panel_q4 = Tensor::zero::<f32>(&[k * r])?;
+                q.repack_panel(&packed_q4, &packer, panel, panel_q4.as_bytes_mut().as_mut_ptr())?;
+                eprintln!("{panel_q4:?}");
+                assert_eq!(panel_q4.as_slice::<f32>()?, panel_f32);
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/linalg/src/frame/block_quant/repack.rs b/linalg/src/frame/block_quant/repack.rs
new file mode 100644
index 0000000000..7d07623f20
--- /dev/null
+++ b/linalg/src/frame/block_quant/repack.rs
@@ -0,0 +1,45 @@
+use crate::mmm::MMMInputValue;
+
+use super::*;
+use crate::frame::PackedFormat;
+
+#[derive(Clone, Hash)]
+pub struct RepackingPackedBlockQuantValue {
+    pub value: EagerPackedInput,
+    pub pack: PackedFormat,
+}
+
+impl Display for RepackingPackedBlockQuantValue {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?} repacked to {:?}", self.value, self.pack)
+    }
+}
+
+impl Debug for RepackingPackedBlockQuantValue {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        <Self as Display>::fmt(self, f)
+    }
+}
+
+impl MMMInputValue for RepackingPackedBlockQuantValue {
+    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
+        Some(self.pack.single_panel_layout(self.value.k, 4))
+    }
+    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> {
+        let buffer = buffer.context("Scratch panel expected")?;
+        let pbqf = self.value.format.downcast_ref::<PackedBlockQuantFormat>().unwrap();
+        unsafe {
+            pbqf.bq.repack_panel(&self.value, &self.pack, i, buffer)?;
+        }
+        Ok(buffer)
+    }
+    fn mn(&self) -> usize {
+        self.value.mn
+    }
+    fn r(&self) -> usize {
+        self.pack.r
+    }
+    fn k(&self) -> usize {
+        self.value.k
+    }
+}
diff --git a/linalg/src/frame/mmm/fuse.rs b/linalg/src/frame/mmm/fuse.rs
index 0b37e2fc18..8ad0e10f03 100644
--- a/linalg/src/frame/mmm/fuse.rs
+++ b/linalg/src/frame/mmm/fuse.rs
@@ -1,6 +1,6 @@
 use std::fmt::Debug;
 
-use super::{MMMInput, OutputStore, OutputStoreKer};
+use super::{MMMInputValue, OutputStore, OutputStoreKer};
 use tract_data::internal::*;
 
 #[repr(usize)]
@@ -48,7 +48,7 @@ pub enum FusedSpec<'t> {
     RoundingShiftRight(usize, RoundingPolicy),
     ShiftLeft(usize),
     Store(OutputStore),
-    AddMatMul { a: &'t dyn MMMInput, b: &'t dyn MMMInput, packing: usize },
+    AddMatMul { a: &'t dyn MMMInputValue, b: &'t dyn MMMInputValue, packing: usize },
 }
 
 impl<'t> FusedSpec<'t> {
diff --git a/linalg/src/frame/mmm/input_store.rs b/linalg/src/frame/mmm/input_store.rs
index ded3300cc7..ea805f8c0b 100644
--- a/linalg/src/frame/mmm/input_store.rs
+++ b/linalg/src/frame/mmm/input_store.rs
@@ -1,4 +1,5 @@
 use downcast_rs::{impl_downcast, Downcast};
+use dyn_clone::DynClone;
 use dyn_hash::DynHash;
 use std::alloc::Layout;
 use std::fmt::{Debug, Display};
@@ -6,20 +7,24 @@ use std::hash::Hash;
 use std::sync::Arc;
 use tract_data::internal::*;
 
-pub trait MMMInputFormat: Downcast + Debug {
+pub trait MMMInputFormat: Downcast + Debug + DynHash + DynClone + Send + Sync + Display {
     fn can_prepare_types(&self) -> Vec<DatumType>;
     fn prepare_tensor(
         &self,
         t: &Tensor,
         k_axis: usize,
         mn_axis: usize,
-    ) -> TractResult<Box<dyn MMMInput>>;
+    ) -> TractResult<Box<dyn MMMInputValue>>;
+    fn r(&self) -> usize;
+    fn k_alignment(&self) -> usize;
 }
 
+dyn_clone::clone_trait_object!(MMMInputFormat);
 impl_downcast!(MMMInputFormat);
+dyn_hash::hash_trait_object!(MMMInputFormat);
 
-pub trait MMMInput: dyn_clone::DynClone + Debug + DynHash + Send + Sync + Display {
+pub trait MMMInputValue: DynClone + Debug + DynHash + Send + Sync + Display {
     fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
-    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> *const u8;
+    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
     fn panels_count(&self) -> usize {
         self.mn().divceil(self.r())
     }
@@ -27,32 +32,32 @@ pub trait MMMInput: dyn_clone::DynClone + Debug + DynHash + Send + Sync + Display {
     fn r(&self) -> usize;
     fn k(&self) -> usize;
 }
-dyn_clone::clone_trait_object!(MMMInput);
-dyn_hash::hash_trait_object!(MMMInput);
+dyn_clone::clone_trait_object!(MMMInputValue);
+dyn_hash::hash_trait_object!(MMMInputValue);
 
-impl From<Box<dyn MMMInput>> for Opaque {
-    fn from(value: Box<dyn MMMInput>) -> Self {
+impl From<Box<dyn MMMInputValue>> for Opaque {
+    fn from(value: Box<dyn MMMInputValue>) -> Self {
         Opaque(Arc::new(value))
     }
 }
 
-impl OpaquePayload for Box<dyn MMMInput> {}
+impl OpaquePayload for Box<dyn MMMInputValue> {}
 
-#[derive(Debug, Clone, Hash)]
+#[derive(Clone, Hash)]
 pub struct EagerPackedInput {
-    pub packed: Tensor,
+    pub format: Box<dyn MMMInputFormat>,
+    pub packed: Blob,
     pub panel_bytes: usize,
     pub mn: usize,
-    pub r: usize,
     pub k: usize,
 }
 
-impl MMMInput for EagerPackedInput {
+impl MMMInputValue for EagerPackedInput {
     fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
         None
     }
-    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> *const u8 {
-        unsafe { self.packed.as_ptr_unchecked::<u8>().add(i * self.panel_bytes) }
+    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
+        unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
    }
     fn k(&self) -> usize {
         self.k
@@ -61,12 +66,18 @@ impl MMMInputValue for EagerPackedInput {
         self.mn
     }
     fn r(&self) -> usize {
-        self.r
+        self.format.r()
     }
 }
 
 impl Display for EagerPackedInput {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Eagerly packed tensor")
+        write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn(), self.k())
+    }
+}
+
+impl Debug for EagerPackedInput {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        <Self as Display>::fmt(self, f)
     }
 }
diff --git a/linalg/src/frame/mmm/macros.rs b/linalg/src/frame/mmm/macros.rs
index 23d17cd71e..5dce77104f 100644
--- a/linalg/src/frame/mmm/macros.rs
+++ b/linalg/src/frame/mmm/macros.rs
@@ -12,7 +12,7 @@ macro_rules! MMMExternKernel {
             use crate::frame::mmm::*;
             extern_kernel!(fn $func(op: *const FusedKerSpec<$ti>) -> isize);
         }
-        MMMKernelWrapper!($ti, $func; [<sys_ $func>]::$func; $mr, $nr; $alignment_bytes_packed_a, $alignment_bytes_packed_b; $end_padding_packed_a, $end_padding_packed_b; $prefetch, $cond
+        MMMKernelWrapper!($ti, $func; [<sys_ $func>]::$func; $mr, $nr; $alignment_bytes_packed_a, $alignment_bytes_packed_b; $end_padding_packed_a, $end_padding_packed_b; $prefetch, $cond $(, can_fuse: $can_fuse )? $(, packing_defs: { $($packing_def)* } )? $(, packings: $($packing)* )?
@@ -37,7 +37,7 @@ macro_rules! MMMKernelWrapper {
 
         paste! {
             #[allow(non_camel_case_types)]
-            #[derive(Copy, Clone, Debug, new, Default)]
+            #[derive(Copy, Clone, new, Default)]
             pub struct [<$id:camel>];
 
             impl [<$id:camel>] {
@@ -46,13 +46,19 @@ macro_rules! MMMKernelWrapper {
                 }
             }
 
+            impl std::fmt::Debug for [<$id:camel>] {
+                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                    write!(f, stringify!($id))
+                }
+            }
+
             mod [<packing_ $id>] {
-                use $crate::frame::mmm::pack::Packer;
+                use $crate::frame::mmm::pack::PackedFormat;
                 use $crate::frame::mmm::MMMInputFormat;
                 use tract_data::prelude::*;
-                const NATIVE_A: Packer = Packer::new(DatumType::[<$ti:upper>], $mr, $alignment_bytes_packed_a, $end_padding_packed_a);
-                const NATIVE_B: Packer = Packer::new(DatumType::[<$ti:upper>], $nr, $alignment_bytes_packed_b, $end_padding_packed_b);
+                const NATIVE_A: PackedFormat = PackedFormat::new(DatumType::[<$ti:upper>], $mr, $alignment_bytes_packed_a, $end_padding_packed_a);
+                const NATIVE_B: PackedFormat = PackedFormat::new(DatumType::[<$ti:upper>], $nr, $alignment_bytes_packed_b, $end_padding_packed_b);
                 const NATIVE: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&NATIVE_A, &NATIVE_B);
                 $($($packing_def)*)?
                 pub const PACKINGS: &[(&dyn MMMInputFormat, &dyn MMMInputFormat)] = &[NATIVE $( $(, $packing)* )?];
diff --git a/linalg/src/frame/mmm/pack.rs b/linalg/src/frame/mmm/pack.rs
index b201798cd5..f8a435dcc5 100644
--- a/linalg/src/frame/mmm/pack.rs
+++ b/linalg/src/frame/mmm/pack.rs
@@ -1,24 +1,24 @@
 use std::alloc::Layout;
-use std::fmt::Debug;
+use std::fmt::{Debug, Display};
 use std::marker::PhantomData;
 use std::ops::Range;
 use tract_data::internal::*;
 
-use crate::mmm::{EagerPackedInput, MMMInput};
+use crate::mmm::{EagerPackedInput, MMMInputValue};
 
 use super::MMMInputFormat;
 
-#[derive(Clone, Debug, Eq, PartialEq, Hash)]
-pub struct Packer {
+#[derive(Clone, Eq, PartialEq, Hash)]
+pub struct PackedFormat {
     pub dt: DatumType,
     pub r: usize,
     pub alignment: usize,
     pub end_padding_record: usize,
 }
 
-impl MMMInputFormat for Packer {
+impl MMMInputFormat for PackedFormat {
     fn can_prepare_types(&self) -> Vec<DatumType> {
-        vec!(self.dt)
+        vec![self.dt]
     }
 
     fn prepare_tensor(
@@ -26,14 +26,39 @@ impl MMMInputFormat for Packer {
         t: &Tensor,
         k_axis: usize,
         mn_axis: usize,
-    ) -> TractResult<Box<dyn MMMInput>> {
-        Packer::pack_tensor(self, t, k_axis, mn_axis)
+    ) -> TractResult<Box<dyn MMMInputValue>> {
+        PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
+    }
+
+    fn r(&self) -> usize {
+        self.r
+    }
+
+    fn k_alignment(&self) -> usize {
+        1
+    }
+}
+
+impl Display for PackedFormat {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Packed{:?}(r={})", self.dt, self.r)
+    }
 }
 
-impl Packer {
-    pub const fn new(dt: DatumType, nr: usize, alignment: usize, end_padding_record: usize) -> Packer {
-        Packer { dt, r: nr, alignment, end_padding_record }
+impl Debug for PackedFormat {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        <Self as Display>::fmt(self, f)
+    }
+}
+
+impl PackedFormat {
+    pub const fn new(
+        dt: DatumType,
+        nr: usize,
+        alignment: usize,
+        end_padding_record: usize,
+    ) -> PackedFormat {
+        PackedFormat { dt, r: nr, alignment, end_padding_record }
     }
 
     #[inline]
@@ -67,7 +92,8 @@ impl PackedFormat {
         t: &Tensor,
         k_axis: usize,
         mn_axis: usize,
-    ) -> TractResult<Box<dyn MMMInput>> {
+    ) -> TractResult<Box<dyn MMMInputValue>> {
+        ensure!(t.datum_type().is_copy());
         ensure!(t.datum_type().unquantized() == self.dt.unquantized());
         let k = t.shape()[k_axis];
         let mn = t.shape()[mn_axis];
@@ -77,10 +103,13 @@ impl Packer {
Packer { let strides = t.strides(); unsafe { let mut packed = - Tensor::uninitialized_aligned_dt(t.datum_type(), &[packed_len], self.alignment())?; + Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + if cfg!(debug_assertions) { + packed.as_bytes_mut().fill(0u8); + } dispatch_copy!(Self::pack_t(t.datum_type())( self, - packed.as_ptr_mut_unchecked(), + packed.as_mut_ptr() as _, t.as_ptr_unchecked(), mn, strides[k_axis], @@ -88,7 +117,13 @@ impl Packer { 0..k, 0..mn )); - Ok(Box::new(EagerPackedInput { packed, panel_bytes, mn, k, r: self.r })) + Ok(Box::new(EagerPackedInput { + format: Box::new(self.clone()), + packed, + panel_bytes, + mn, + k, + })) } } @@ -97,7 +132,7 @@ impl Packer { t: &TensorView, k_axis: usize, mn_axis: usize, - ) -> TractResult> { + ) -> TractResult> { ensure!(t.datum_type().unquantized() == self.dt.unquantized()); let k = t.shape()[k_axis]; let mn = t.shape()[mn_axis]; @@ -107,10 +142,13 @@ impl Packer { let strides = t.strides(); unsafe { let mut packed = - Tensor::uninitialized_aligned_dt(t.datum_type(), &[packed_len], self.alignment())?; + Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + if cfg!(debug_assertions) { + packed.as_bytes_mut().fill(0u8); + } dispatch_copy!(Self::pack_t(t.datum_type())( self, - packed.as_ptr_mut_unchecked(), + packed.as_mut_ptr() as _, t.as_ptr_unchecked(), mn, strides[k_axis], @@ -118,7 +156,13 @@ impl Packer { 0..k, 0..mn )); - Ok(Box::new(EagerPackedInput { packed, panel_bytes, mn, k, r: self.r })) + Ok(Box::new(EagerPackedInput { + format: Box::new(self.clone()), + packed, + panel_bytes, + mn, + k, + })) } } @@ -147,6 +191,9 @@ impl Packer { k_range: Range, mn_range: Range, ) { + if k_range.len() == 0 || mn_range.len() == 0 { + return + } if self.r == 1 && k_stride == 1 && mn == 1 { pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len()) } else if mn_stride == 1 { @@ -493,7 +540,7 @@ mod test { fn packer(&self) -> Array2 { let panels = self.mn_range.len().divceil(self.r); - let packer = super::Packer::new(u32::datum_type(), self.r, self.align_panel, 0); + let packer = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel, 0); let input = self.input().into_tensor(); let panel_len = packer.single_panel_len(self.k_range.len()); let mut output = diff --git a/linalg/src/frame/mmm/scratch.rs b/linalg/src/frame/mmm/scratch.rs index df3f511fa7..aa5185501f 100644 --- a/linalg/src/frame/mmm/scratch.rs +++ b/linalg/src/frame/mmm/scratch.rs @@ -268,12 +268,12 @@ impl ScratchSpaceImpl { (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp).as_mut().unwrap(); if scratch.panel_a_id != down { scratch.ptr_a = - a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o))); + a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_a_id = down; } if scratch.panel_b_id != right { scratch.ptr_b = - b.panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o))); + b.panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_b_id = right; } FKS::AddMatMul { @@ -442,12 +442,12 @@ impl ScratchSpaceImpl { let scratch = (loc as *mut AddMatMulTemp).as_mut().unwrap(); if scratch.panel_a_id != down { scratch.ptr_a = - a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o))); + a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_a_id = down; } if scratch.panel_b_id != right { scratch.ptr_b = - b.panel_bytes(right, buffer_b.map(|o| 
tls.blob.as_mut_ptr().add(o))); + b.panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_b_id = right; } FKS::AddMatMul { diff --git a/linalg/src/frame/mmm/tests/frame.rs b/linalg/src/frame/mmm/tests/frame.rs index ce75971867..89d393ebc5 100644 --- a/linalg/src/frame/mmm/tests/frame.rs +++ b/linalg/src/frame/mmm/tests/frame.rs @@ -1,9 +1,8 @@ +use crate::frame::mmm::*; use crate::LADatum; use num_traits::AsPrimitive; -use crate::frame::mmm::*; -use crate::frame::mmm::pack::Packer; -use proptest::prelude::*; use std::ops::Neg; +use tests::display_error; use tract_data::internal::*; #[macro_export] @@ -15,319 +14,73 @@ macro_rules! mmm_frame_tests { #[allow(unused_imports)] use $crate::frame::mmm::tests::frame::*; - proptest::proptest! { - #[test] - fn mat_mul_prepacked_prop((m, k, n, ref a, ref b) in strat_mat_mat_mul(& $ker, 0)) { - if $cond { - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, m, k, n, &a, &b)?; - } - } - - #[test] - fn mat_vec_prepacked_prop((m, k, ref a, ref b) in strat_mat_vec_mul(& $ker, 0)) { - if $cond { - test_mat_vec_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, m, k, &*a, b)? - } - } - } - - #[test] - fn mat_mul_1() { - if $cond { - let a = tensor2(&[[-3i32, 3, 5, -5], [6, 0, -6, -5], [0, 0, 9, 7]]) - .cast_to::<$ta>() - .unwrap() - .into_owned(); - let b = tensor2(&[[-8i32, 5], [5, -3], [5, 7], [-8, -1]]) - .cast_to::<$tb>() - .unwrap() - .into_owned(); - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 3, 4, 2, &a, &b).unwrap() - } - } - - #[test] - fn mat_mul_2() { - if $cond { - let a = tensor2(&[[1i32]]).cast_to::<$ta>().unwrap().into_owned(); - let b = tensor2(&[[0i32, 0, 1]]).cast_to::<$tb>().unwrap().into_owned(); - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 1, 1, 3, &a, &b).unwrap() - } - } - - #[test] - fn mat_mul_3() { - if $cond { - let a = tensor2(&[[-3i32, 3, 5, -5], [6, 0, -6, -5], [0, 0, 9, 7]]) - .cast_to::<$ta>() - .unwrap() - .into_owned(); - let b = tensor2(&[[-8i32, 5], [5, -3], [5, 7], [-8, -1]]) - .cast_to::<$tb>() - .unwrap() - .into_owned(); - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 3, 4, 2, &a, &b).unwrap() - } - } - - #[test] - fn mat_mul_4() { - if $cond { - let a = tensor2(&[[122, 82]]).cast_to::<$ta>().unwrap().into_owned(); - let b = - tensor2(&[[0, 0, 37], [0, 0, 57]]).cast_to::<$tb>().unwrap().into_owned(); - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 1, 2, 3, &a, &b).unwrap() - } - } - - #[test] - fn mat_mul_1_1_1() { - if $cond { - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>( - $ker, - 1, - 1, - 1, - &tensor2(&[[26]]).cast_to::<$ta>().unwrap(), - &tensor2(&[[48]]).cast_to::<$tb>().unwrap(), - ) - .unwrap() - } - } - - #[test] - fn mat_mul_1_2_1() { - if $cond { - test_mat_mat_mul_prep::<_, $ta, $tb, $tc, $ti>( - $ker, - 1, - 2, - 1, - &tensor2(&[[0, 1]]).cast_to::<$ta>().unwrap(), - &tensor2(&[[0], [1]]).cast_to::<$tb>().unwrap(), - ) - .unwrap() - } - } - - #[test] - fn mat_vec_1() { - if $cond { - let a = tensor2(&[[0], [1]]).cast_to::<$ta>().unwrap().into_owned(); - let b = tensor1(&[1]).cast_to::<$tb>().unwrap().into_owned(); - test_mat_vec_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 2, 1, &a, &b).unwrap() - } - } - #[test] - fn mat_vec_2() { + fn row_mul_2_1_3() -> TractResult<()> { if $cond { - let a = tensor1(&[0, 0, 0, 0, 0, 0, -4, 1]).into_shape(&[8, 1]).unwrap(); - let a = a.cast_to::<$ta>().unwrap(); - let b = tensor1(&[-64]).cast_to::<$tb>().unwrap().into_owned(); - test_mat_vec_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 8, 1, &a, &b).unwrap() + unsafe { 
row_mul::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn mat_vec_3() { + fn row_add_2_1_3() -> TractResult<()> { if $cond { - let a = tensor1(&[0, 0]).into_shape(&[1, 2]).unwrap(); - let a = a.cast_to::<$ta>().unwrap(); - let b = tensor1(&[0, 0]).cast_to::<$tb>().unwrap().into_owned(); - test_mat_vec_mul_prep::<_, $ta, $tb, $tc, $ti>($ker, 1, 2, &a, &b).unwrap() + unsafe { row_add::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn row_mul_2_1_3() { + fn col_mul_2_1_3() -> TractResult<()> { if $cond { - unsafe { row_mul::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } + unsafe { col_mul::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn row_add_2_1_3() { + fn col_add_2_1_3() -> TractResult<()> { if $cond { - unsafe { row_add::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } + unsafe { col_add::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn col_mul_2_1_3() { + fn max_2_1_3() -> TractResult<()> { if $cond { - unsafe { col_mul::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } + unsafe { max::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn col_add_2_1_3() { + fn min_2_1_3() -> TractResult<()> { if $cond { - unsafe { col_add::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } + unsafe { min::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn max_2_1_3() { + fn add_d_2_1_3() -> TractResult<()> { if $cond { - unsafe { max::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } + unsafe { add_d::<_, $ta, $tb, $tc, $ti>($ker, 2, 3)? } } + Ok(()) } #[test] - fn min_2_1_3() { + fn add_d_big() -> TractResult<()> { if $cond { - unsafe { min::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } - } - } - - #[test] - fn add_d_2_1_3() { - if $cond { - unsafe { add_d::<_, $ta, $tb, $tc, $ti>($ker, 2, 3).unwrap() } - } - } - - #[test] - fn add_d_big() { - if $cond { - unsafe { add_d::<_, $ta, $tb, $tc, $ti>($ker, 197, 1).unwrap() } + unsafe { add_d::<_, $ta, $tb, $tc, $ti>($ker, 197, 1)? 
} } + Ok(()) } } }; } -fn tensor(dt: DatumType, shape: Vec) -> BoxedStrategy { - let len = shape.iter().product::(); - // for f16, positive numbers only to avoid worst rounding side effects - // and not too big either to avoid overflow :) - let number = if dt == f16::datum_type() { - (0i16..100).boxed() - } else { - any::().prop_map(|i| i as i16).boxed() - }; - proptest::collection::vec(number, len..=len) - .prop_map(move |vec| { - tract_ndarray::ArrayD::from_shape_vec(shape.clone(), vec) - .unwrap() - .into_tensor() - .cast_to_dt(dt) - .unwrap() - .into_owned() - }) - .boxed() -} - -pub fn strat_mat_mat_mul( - ker: &dyn MatMatMul, - packing: usize, -) -> BoxedStrategy<(usize, usize, usize, Tensor, Tensor)> { - let dta = ker.packings()[packing].0.downcast_ref::().unwrap().dt; - let dtb = ker.packings()[packing].1.downcast_ref::().unwrap().dt; - (1usize..5, 1usize..5, 1usize..5) - .prop_flat_map(move |(m, k, n)| { - (Just(m), Just(k), Just(n), tensor(dta, vec![m, k]), tensor(dtb, vec![k, n])) - }) - .boxed() -} - -pub fn strat_mat_vec_mul( - ker: &dyn MatMatMul, - packing: usize, -) -> BoxedStrategy<(usize, usize, Tensor, Tensor)> { - let dta = ker.packings()[packing].0.downcast_ref::().unwrap().dt; - let dtb = ker.packings()[packing].1.downcast_ref::().unwrap().dt; - (1usize..15, 1usize..15) - .prop_flat_map(move |(m, k)| { - (Just(m), Just(k), tensor(dta, vec![m, k]), tensor(dtb, vec![k, 1])) - }) - .boxed() -} - -pub fn test_mat_mat_mul_prep + 'static, TA, TB, TC, TI>( - ker: K, - m: usize, - k: usize, - n: usize, - a: &Tensor, - b: &Tensor, -) -> Result<(), proptest::test_runner::TestCaseError> -where - TA: LADatum + AsPrimitive + 'static, - TB: LADatum + AsPrimitive + 'static, - TC: LADatum + AsPrimitive + 'static, - TI: LADatum + AsPrimitive, - i32: AsPrimitive, - usize: AsPrimitive, -{ - crate::setup_test_logger(); - assert_eq!(a.datum_type(), TA::datum_type()); - unsafe { - let packing = 0; - let pack = &ker.packings()[0]; - let packed_a = pack.0.prepare_tensor(a, 1, 0).unwrap(); - let packed_b = pack.1.prepare_tensor(b, 0, 1).unwrap(); - - fused_ops::( - ker, - m, - n, - &[FusedSpec::AddMatMul { a: &*packed_a, b: &*packed_b, packing }], - |r, c| { - let mut v: TI = TI::zero(); - for i in 0..k { - let a: TI = a.as_slice::().unwrap()[i + k * r].as_(); - let b: TI = b.as_slice::().unwrap()[c + i * n].as_(); - v += a * b; - } - v.as_() - }, - ) - } -} - -pub fn test_mat_vec_mul_prep + 'static, TA, TB, TC, TI>( - ker: K, - m: usize, - k: usize, - a: &Tensor, - b: &Tensor, -) -> Result<(), proptest::test_runner::TestCaseError> -where - TA: LADatum + AsPrimitive + 'static, - TB: LADatum + AsPrimitive + 'static, - TC: LADatum + AsPrimitive + 'static, - TI: LADatum + AsPrimitive + 'static + Neg, - i32: AsPrimitive, - usize: AsPrimitive, -{ - crate::setup_test_logger(); - unsafe { - let packing = 0; - let b = b.clone().into_shape(&[k, 1]).unwrap(); - let pack = &ker.packings()[packing]; - let packed_a = pack.0.prepare_tensor(a, 1, 0).unwrap(); - let packed_b = pack.1.prepare_tensor(&b, 0, 1).unwrap(); - - fused_ops::( - ker, - m, - 1, - &[FusedSpec::AddMatMul { a: &*packed_a, b: &*packed_b, packing }], - |r, _| { - let mut inter = TI::zero(); - for i in 0..k { - let a: TI = a.as_slice::().unwrap()[i + k * r].as_(); - let b: TI = b.as_slice::().unwrap()[i].as_(); - inter += a * b; - } - inter.as_() - }, - ) - } -} - pub unsafe fn fused_ops< K: MatMatMulKer + 'static, TA, @@ -341,7 +94,7 @@ pub unsafe fn fused_ops< n: usize, spec: &[FusedSpec], expect: F, -) -> 
proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -352,43 +105,28 @@ where { crate::setup_test_logger(); - let mut found = Tensor::zero::(&[m, n]).unwrap(); + let mut found = Tensor::zero::(&[m, n])?; let c_store = ker .c_from_data_and_strides(TC::datum_type().size_of(), n as isize, 1) .wrap(&found.view_mut()); let mut spec: TVec = spec.into(); spec.push(FusedSpec::Store(c_store)); - ker.run(m, n, &spec).unwrap(); + ker.run(m, n, &spec)?; let expected = tract_ndarray::prelude::Array2::from_shape_fn((m, n), |(r, c)| expect(r, c)).into_tensor(); - if found.close_enough(&expected, true).is_err() { - println!("found, expected:"); - for r in 0..m { - for c in 0..n { - let f = found.as_slice_unchecked::()[r * n + c]; - let e = expected.as_slice_unchecked::()[r * n + c]; - let mut s = format!("{:4} ", f); - if f != e { - s = nu_ansi_term::Color::Red.paint(s).to_string(); - } - print!("{:4} ", s); - } - print!(" "); - for c in 0..n { - print!("{:4} ", expected.as_slice_unchecked::()[r * n + c]); - } - println!(); - } + let err = found.close_enough(&expected, true); + if err.is_err() { + display_error(found.as_slice::()?, expected.as_slice::()?, m, n); } - found.close_enough(&expected, true).map_err(|e| TestCaseError::Fail(e.to_string().into())) + err } pub unsafe fn row_add + 'static, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -411,7 +149,7 @@ pub unsafe fn row_mul + 'static, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -437,7 +175,7 @@ pub unsafe fn col_add + 'static, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -460,7 +198,7 @@ pub unsafe fn col_mul + 'static, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -486,7 +224,7 @@ pub unsafe fn add_d + 'static, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -496,14 +234,15 @@ where usize: AsPrimitive, { let d = (0..m * n).map(|i| i.as_()).collect::>(); - let d = tensor1(&d).into_shape(&[m, n]).unwrap(); + let d = tensor1(&d).into_shape(&[m, n])?; let store_spec = OutputStoreSpec::View { m_axis: 0, n_axis: 1, mr: ker.mr(), nr: ker.nr() }; + let view_d = d.to_array_view::()?.into_dimensionality()?; fused_ops::( ker, m, n, &[FusedSpec::AddUnicast(store_spec.wrap(&d.view()))], - |r, c| d.to_array_view_unchecked::().into_dimensionality().unwrap()[(r, c)].as_(), + |r, c| view_d[(r, c)].as_(), ) } @@ -511,7 +250,7 @@ pub unsafe fn max, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 'static, @@ -534,7 +273,7 @@ pub unsafe fn min, TA, TB, TC, TI>( ker: K, m: usize, n: usize, -) -> proptest::test_runner::TestCaseResult +) -> TractResult<()> where TA: LADatum + AsPrimitive + 'static, TB: LADatum + AsPrimitive + 
'static, diff --git a/linalg/src/frame/mmm/tests/mod.rs b/linalg/src/frame/mmm/tests/mod.rs index 9418ea2543..07bf0b6a72 100644 --- a/linalg/src/frame/mmm/tests/mod.rs +++ b/linalg/src/frame/mmm/tests/mod.rs @@ -1,3 +1,6 @@ +use proptest::prelude::*; +use tract_data::internal::*; + use crate::LADatum; #[macro_use] @@ -28,7 +31,7 @@ macro_rules! test_mmm_kernel { #[macro_export] macro_rules! test_mmm_kernel_f16 { ($k: ident, $cond: expr) => { - mmm_packed_packed_tests!($cond, $k, f16f16:0, f16, f16, f16, f16); + mmm_packed_packed_tests!($cond, $k, f16f16:0); mmm_frame_tests!($cond, $k, f16, f16, f16, f16); mmm_kernel_fuse_tests!($cond, $k, f16, f16); }; @@ -37,7 +40,7 @@ macro_rules! test_mmm_kernel_f16 { #[macro_export] macro_rules! test_mmm_kernel_f32 { ($k: ident, $cond: expr) => { - mmm_packed_packed_tests!($cond, $k, f32f32:0, f32, f32, f32, f32); + mmm_packed_packed_tests!($cond, $k, f32f32:0); mmm_frame_tests!($cond, $k, f32, f32, f32, f32); mmm_kernel_fuse_tests!($cond, $k, f32, f32); }; @@ -46,7 +49,7 @@ macro_rules! test_mmm_kernel_f32 { #[macro_export] macro_rules! test_mmm_kernel_f64 { ($k: ident, $cond: expr) => { - mmm_packed_packed_tests!($cond, $k, f64f64:0, f64, f64, f64, f64); + mmm_packed_packed_tests!($cond, $k, f64f64:0); mmm_frame_tests!($cond, $k, f64, f64, f64, f64); mmm_kernel_fuse_tests!($cond, $k, f64, f64); }; @@ -55,7 +58,7 @@ macro_rules! test_mmm_kernel_f64 { #[macro_export] macro_rules! test_mmm_kernel_i32 { ($k: ident, $cond: expr) => { - mmm_packed_packed_tests!($cond, $k, i32i32:0, i32, i32, i32, i32); + mmm_packed_packed_tests!($cond, $k, i32i32:0); mmm_kernel_fuse_tests!($cond, $k, i32, i32); mmm_frame_tests!($cond, $k, i32, i32, i32, i32); mmm_q_scale_tests!($cond, $k); @@ -64,22 +67,19 @@ macro_rules! test_mmm_kernel_i32 { fn display_error(v: &[TC], expected: &[TC], m: usize, n: usize) { if v != expected { - println!("found, expected:"); - for m in 0..m { - for n in 0..n { + for ixm in 0..m { + for ixn in 0..n { use nu_ansi_term::Color::*; - let f = v[m * n + n]; - let e = expected[m * n + n]; + let f = v[ixm * n + ixn]; + let e = expected[ixm * n + ixn]; let color = if f != e { Red } else { Green }; print!("{} ", color.paint(format!("{:4}", f))); } print!(" "); - for n in 0..n { - print!("{:4} ", expected[m * n + n]); + for ixn in 0..n { + print!("{:4} ", expected[ixm * n + ixn]); } println!(); } } } - - diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs index dfe76811c7..ae48fdef2a 100644 --- a/linalg/src/frame/mmm/tests/packed_packed.rs +++ b/linalg/src/frame/mmm/tests/packed_packed.rs @@ -1,87 +1,155 @@ +use crate::frame::block_quant::PackedBlockQuantFormat; use crate::frame::mmm::*; -use crate::frame::Packer; -use crate::LADatum; -use num_traits::{AsPrimitive, One, Zero}; +use pack::PackedFormat; use proptest::collection::vec; use proptest::prelude::*; -use std::fmt; use std::fmt::Debug; -use std::marker::PhantomData; +use tests::display_error; use tract_data::internal::*; #[macro_export] macro_rules! 
mmm_packed_packed_tests { - ($cond:expr, $ker:ident, $packing_id:ident : $packing: expr, $ta:ty, $tb:ty, $tc:ty, $ti: ty) => { + ($cond:expr, $ker:ident, $packing_id:ident : $packing: expr) => { mod $packing_id { #[allow(unused_imports)] use super::$ker; - use num_traits::Zero; use proptest::prelude::*; #[allow(unused_imports)] use tract_data::prelude::f16; + use tract_data::prelude::*; + use tract_itertools::Itertools; #[allow(unused_imports)] use $crate::frame::mmm::tests::packed_packed::*; use $crate::frame::mmm::MatMatMulKer; - proptest::proptest! { - #[test] - fn packed_packed_prop(pb in any_with::>(($ker, $packing))) { + mod fuse { + use super::*; + + proptest::proptest! { + #[test] + fn prop(pb in any_with::>((false, $ker, $packing))) { + if $cond { + pb.check().unwrap() + } + } + } + + fn t(a: impl Into>, b: impl Into>) -> TractResult<()> { if $cond { - prop_assert_eq!(pb.run(), pb.reference()) + PackedPackedProblem::kernel($ker, $packing, a, b).check()?; } + Ok(()) } - } - #[test] - fn packed_packed_1() { - if $cond { - packed_packed::<_, $ta, $tb, $tc, $ti>($ker, $packing, 1) + #[test] + fn packed_packed_1() -> TractResult<()> { + t(vec![1f32; $ker.mr()], vec![1f32; $ker.nr()]) } - } - #[test] - fn packed_packed_2() { - if $cond { - packed_packed::<_, $ta, $tb, $tc, $ti>($ker, $packing, 2) + #[test] + fn packed_packed_2() -> TractResult<()> { + t(vec![1f32; $ker.mr() * 2], vec![1f32; $ker.nr() * 2]) } - } - #[test] - fn packed_packed_13() { - if $cond { - packed_packed::<_, $ta, $tb, $tc, $ti>($ker, $packing, 13) + #[test] + fn packed_packed_13() -> TractResult<()> { + t(vec![1f32; $ker.mr() * 13], vec![1f32; $ker.nr() * 13]) + } + + #[test] + fn packed_packed_a_scale() -> TractResult<()> { + t((1..=$ker.mr() as i64).map(|x| x as f32).collect_vec(), vec![1f32; $ker.nr()]) } - } - #[test] - fn packed_packed_empty() { - if $cond { - let pb = PackedPackedProblem::<_, $ta, $tb, $tc, $ti>::new( - $ker, - $packing, - 0, - vec![<$ta>::zero(); 0], - vec![<$tb>::zero(); 0], - false, - false, - ); - assert_eq!(pb.run(), pb.reference()) + #[test] + fn packed_packed_a_scale_times_2() -> TractResult<()> { + t( + (1..=2 * $ker.mr() as i64).map(|x| x as f32).collect_vec(), + vec![1f32; $ker.nr() * 2], + ) + } + + #[test] + fn packed_packed_empty() -> TractResult<()> { + t(vec![0f32; 0], vec![0f32; 0]) + } + + #[test] + fn packed_packed_bug_1() -> TractResult<()> { + t(vec![0f32; $ker.mr()], vec![0f32; $ker.nr()]) + } + + #[test] + fn packed_packed_bug_2() -> TractResult<()> { + let mut a = vec![0f32; $ker.mr()]; + a[0] = 1.; + let mut b = vec![0f32; $ker.nr()]; + b[0] = 1.; + t(a, b) + } + + #[test] + fn packed_packed_bug_3() -> TractResult<()> { + if $ker.mr() >= 4 { + let mut a = vec![0f32; 2 * $ker.mr()]; + let mut b = vec![0f32; 2 * $ker.nr()]; + a[2] = -0.7548828f32; + a[3] = 0.23547363f32; + b[2 * $ker.nr() - 1] = 0.93603516; + t(a, b)?; + } + Ok(()) } } - #[test] - fn packed_packed_bug_1() { - if $cond { - let pb = PackedPackedProblem::<_, $ta, $tb, $tc, $ti>::new( - $ker, - $packing, - 1, - vec![<$ta>::zero(); $ker.mr()], - vec![<$tb>::zero(); $ker.nr()], - true, - true, - ); - assert_eq!(pb.run(), pb.reference()) + mod frame { + use super::*; + + proptest::proptest! 
{ + #[test] + fn prop(pb in any_with::>((true, $ker, $packing))) { + if $cond { + pb.check().unwrap() + } + } + } + + fn t( + m: usize, + n: usize, + a: impl Into>, + b: impl Into>, + ) -> TractResult<()> { + if $cond { + PackedPackedProblem::frame($ker, $packing, m, n, a, b).check()?; + } + Ok(()) + } + + fn ti( + m: usize, + n: usize, + a: impl Into>, + b: impl Into>, + ) -> TractResult<()> { + let a = a.into().into_iter().map(|i| i as f32).collect_vec(); + let b = b.into().into_iter().map(|i| i as f32).collect_vec(); + t(m, n, a, b) + } + + #[test] + fn trivial_1x2() -> TractResult<()> { + ti(1, 2, [0], [0, 0]) + } + + #[test] + fn mat_mul_1() -> TractResult<()> { + ti(3, 2, [-3, 3, 5, -5, 6, 0, -6, -5, 0, 0, 9, 7], [-8, 5, 5, -3, 5, 7, -8, -1]) + } + + #[test] + fn mat_mul_2() -> TractResult<()> { + ti(1, 3, [122, 82], [0, 0, 37, 0, 0, 57]) } } } @@ -89,154 +157,195 @@ macro_rules! mmm_packed_packed_tests { } #[derive(Debug, new)] -pub struct PackedPackedProblem +pub struct PackedPackedProblem where - K: MatMatMulKer, - TA: 'static + Debug + AsPrimitive, - TB: 'static + Debug + AsPrimitive, - TC: LADatum + Copy + PartialEq + 'static + Debug, - TI: LADatum + fmt::Display + AsPrimitive, - usize: AsPrimitive + AsPrimitive, + K: MatMatMulKer, { + pub frame_test: Option<(usize, usize)>, pub ker: K, pub packing: usize, - pub k: usize, - pub a: Vec, - pub b: Vec, - pub trans_c: bool, - pub add_one: bool, - pub _phantom: PhantomData<(K, TC, TI)>, + pub a: Vec, + pub b: Vec, } -impl Arbitrary for PackedPackedProblem +impl Arbitrary for PackedPackedProblem where - K: MatMatMulKer + Default + Copy, - TA: 'static + Debug + AsPrimitive, - TB: 'static + Debug + AsPrimitive, - TC: LADatum + Copy + PartialEq + 'static + Debug, - TI: LADatum + fmt::Display + AsPrimitive, - usize: AsPrimitive + AsPrimitive, + K: MatMatMulKer + Default + 'static, { - type Parameters = (K, usize); + type Parameters = (bool, K, usize); type Strategy = BoxedStrategy; - fn arbitrary_with((ker, packing): Self::Parameters) -> Self::Strategy { - (0usize..20, any::(), any::()) - .prop_flat_map(|(k, trans_c, add_one)| { - let ker = K::default(); - let m = k * ker.mr(); - let n = k * ker.nr(); - let a = (0usize..10).prop_map(|x| x.as_()); - let b = (0usize..10).prop_map(|x| x.as_()); - (Just(k), Just(trans_c), Just(add_one), vec(a, m..=m), vec(b, n..=n)) + fn arbitrary_with((frame_test, ker, packing): Self::Parameters) -> Self::Strategy { + let (mr, nr) = (ker.mr(), ker.nr()); + let item_range = + if ker.internal_type().is_integer() { (-5f32)..5f32 } else { (-1f32)..1f32 }; + let (m_range, n_range) = + if frame_test { (1usize..3 * mr, 1usize..3 * nr) } else { (mr..mr + 1, nr..nr + 1) }; + (m_range, 0usize..40, n_range) + .prop_flat_map(move |(m, k, n)| { + ( + vec(item_range.clone(), k * m..=k * m), + vec(item_range.clone(), k * n..=k * n), + Just((m, n)), + ) }) - .prop_map(move |(k, trans_c, add_one, a, b)| Self { + .prop_map(move |(a, b, mn)| Self { + frame_test: Some(mn).filter(|_| frame_test), ker, packing, - k, a, b, - trans_c, - add_one, - _phantom: PhantomData, }) .boxed() } } -impl PackedPackedProblem -where - K: MatMatMulKer, - TA: 'static + Debug + AsPrimitive + Datum, - TB: 'static + Debug + AsPrimitive + Datum, - TC: LADatum + Copy + Zero + PartialEq + 'static + Debug, - TI: LADatum + fmt::Display + AsPrimitive, - usize: AsPrimitive + AsPrimitive, -{ - pub fn reference(&self) -> Vec { - let init = if self.add_one { TI::one() } else { TI::zero() }; - let mr = self.ker.mr(); - let nr = self.ker.nr(); - let mut vi = 
vec![init; mr * nr]; - for m in 0..mr { - for n in 0..nr { - for k in 0..self.k { - let a: TI = self.a[m + mr * k].as_(); - let b: TI = self.b[n + nr * k].as_(); - let offset = if self.trans_c { m + n * mr } else { n + m * nr }; - vi[offset] += a * b; +impl PackedPackedProblem { + pub fn kernel( + ker: K, + packing: usize, + a: impl Into>, + b: impl Into>, + ) -> PackedPackedProblem { + PackedPackedProblem { frame_test: None, ker, packing, a: a.into(), b: b.into() } + } + + pub fn frame( + ker: K, + packing: usize, + m: usize, + n: usize, + a: impl Into>, + b: impl Into>, + ) -> PackedPackedProblem { + PackedPackedProblem { frame_test: Some((m, n)), ker, packing, a: a.into(), b: b.into() } + } + + pub fn mkn(&self) -> (usize, usize, usize) { + let (m, n) = self.frame_test.unwrap_or((self.ker.mr(), self.ker.nr())); + assert!(m != 0 && n != 0); + let k = self.a.len() / m; + assert_eq!(self.b.len() / n, k); + (m, k, n) + } + + pub fn padded_inputs(&self) -> TractResult<(Tensor, Tensor)> { + let (pack_a, pack_b) = self.ker.packings()[self.packing]; + assert!(pack_b.k_alignment() == 1); + let (m, k, n) = self.mkn(); + let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + + let mut a = Tensor::zero::(&[m, k_aligned])?; + for row in 0..m { + for col in 0..k { + a.to_array_view_mut()?[[row, col]] = self.a[col + k * row]; + } + } + if let Some(pf) = pack_a.downcast_ref::() { + a = a.cast_to_dt(pf.dt)?.into_owned(); + } + let mut b = Tensor::zero::(&[k_aligned, n])?; + for row in 0..k { + for col in 0..n { + b.to_array_view_mut()?[[row, col]] = self.b[col + n * row]; + } + } + if let Some(pf) = pack_b.downcast_ref::() { + b = b.cast_to_dt(pf.dt)?.into_owned(); + } + + Ok((a, b)) + } + + pub fn reference(&self) -> TractResult { + let (m, k, n) = self.mkn(); + let pack_a = self.ker.packings()[self.packing].0; + let (mut a, b) = self.padded_inputs()?; + let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + if let Some(pbqf) = pack_a.downcast_ref::() { + a = pbqf.simulate_precision_loss(a, 1)?; + }; + let mut c = Tensor::zero::(&[m, n])?; + + let a = a.cast_to::()?; + let a = a.as_slice::()?; + let b = b.cast_to::()?; + let b = b.as_slice::()?; + let mut view = c.to_array_view_mut::()?.into_dimensionality()?; + for ix_m in 0..m { + for ix_n in 0..n { + for ix_k in 0..k { + let a = a[ix_k + k_aligned * ix_m]; + let b = b[ix_n + n * ix_k]; + view[(ix_m, ix_n)] += a * b; } } } - vi.into_iter().map(|ti| ti.as_()).collect() + Ok(c) } - pub fn run(&self) -> Vec { - unsafe { - let packing = 0; - let pack_a = self.ker.packings()[packing].0.downcast_ref::().unwrap(); - let pack_b = self.ker.packings()[packing].1.downcast_ref::().unwrap(); - let a = self - .a - .iter() - .cloned() - .chain(vec![0.as_(); pack_a.end_padding_record * self.ker.mr()]) - .collect::>(); - let pa = Tensor::from_slice_align(&a, pack_a.alignment).unwrap(); - let b = self - .b - .iter() - .cloned() - .chain(vec![0.as_(); pack_b.end_padding_record * self.ker.nr()]) - .collect::>(); - let pb = Tensor::from_slice_align(&b, pack_b.alignment).unwrap(); - let mut v = vec![TC::zero(); self.ker.mr() * self.ker.nr()]; - let c = if self.trans_c { - mmm_stride_storage(&mut v, 1, self.ker.mr()) - } else { - mmm_stride_storage(&mut v, self.ker.nr(), 1) - }; - let b_store = pb.as_ptr_unchecked::() as _; - - let mut non_linear_ops = tvec!(FusedKerSpec::AddMatMul { - k: self.k, - pa: pa.as_ptr_unchecked::() as _, - pb: b_store, - packing: self.packing, - }); - if self.add_one { - non_linear_ops.push(FusedKerSpec::ScalarAdd(TI::one())); + pub fn 
run(&self) -> TractResult<Tensor> { + let (m, k, n) = self.mkn(); + let (pack_a, pack_b) = self.ker.packings()[self.packing]; + assert!(pack_b.k_alignment() == 1); + let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + + let (a, b) = self.padded_inputs()?; + let pa = pack_a.prepare_tensor(&a, 1, 0)?; + let pb = pack_b.prepare_tensor(&b, 0, 1)?; + + let mut v = Tensor::zero_dt(self.ker.internal_type(), &[m, n])?; + let item_size = self.ker.internal_type().size_of(); + + if self.frame_test.is_some() { + unsafe { + let c = self.ker.c_view(0, 1).wrap(&v.view_mut()); + let ops = tvec!( + FusedSpec::AddMatMul { a: &*pa, b: &*pb, packing: self.packing }, + FusedSpec::Store(c) + ); + self.ker.run(m, n, &ops)?; + } + } else { + let c = OutputStoreKer { + ptr: v.as_bytes_mut().as_mut_ptr(), + row_byte_stride: (item_size * self.ker.nr()) as isize, + col_byte_stride: item_size as isize, + item_size, + }; + + let non_linear_ops = tvec!( + FusedKerSpec::Clear, + FusedKerSpec::AddMatMul { + k: k_aligned, + pa: pa.panel_bytes(0, None)?, + pb: pb.panel_bytes(0, None)?, + packing: self.packing + }, + FusedKerSpec::Store(c), + FusedKerSpec::Done + ); let err = self.ker.kernel(&non_linear_ops); assert_eq!(err, 0); - v } + Ok(v) } -} -pub fn packed_packed(ker: K, packing: usize, k: usize) -where - K: MatMatMulKer, - TA: Copy + One + Datum + AsPrimitive, - TB: Copy + One + Datum + AsPrimitive, - TC: LADatum + Copy + PartialEq + Zero + 'static + Debug, - TI: LADatum + AsPrimitive, - usize: AsPrimitive + AsPrimitive + AsPrimitive, -{ - let a = vec![TA::one(); ker.mr() * k]; - let b = vec![TB::one(); ker.nr() * k]; - let pb = PackedPackedProblem::::new(ker, packing, k, a, b, false, false); - assert_eq!(pb.run(), pb.reference()) -} - -pub fn mmm_stride_storage(v: &mut [T], rsc: usize, csc: usize) -> OutputStoreKer { - OutputStoreKer { - ptr: v.as_mut_ptr() as _, - row_byte_stride: (std::mem::size_of::() * rsc) as isize, - col_byte_stride: (std::mem::size_of::() * csc) as isize, - item_size: std::mem::size_of::(), + pub fn check(&self) -> TractResult<()> { + let expected = self.reference()?; + let found = self.run()?; + let app = if K::Acc::datum_type() == f16::datum_type() { + Approximation::SuperApproximate + } else { + Approximation::Approximate + }; + let result = found.close_enough(&expected, app); + if result.is_err() { + let exp = expected.as_slice::()?; + let found = found.as_slice::()?; + let (m, _, n) = self.mkn(); + display_error(found, exp, m, n); + } + result } } diff --git a/linalg/src/generic/mmm.rs b/linalg/src/generic/mmm.rs index 0accacdd29..88a9641816 100644 --- a/linalg/src/generic/mmm.rs +++ b/linalg/src/generic/mmm.rs @@ -4,6 +4,7 @@ use num_traits::AsPrimitive; use tract_data::prelude::*; use super::*; +use crate::frame::block_quant::{BlockQuant, NibbleReader, PackedBlockQuantFormat, Q4_0}; use crate::frame::mmm::*; use crate::LADatum; @@ -60,6 +61,36 @@ unsafe fn add_mat_mul( } } +unsafe fn add_mat_mul_pq40<TI, const MR: usize, const NR: usize>( + pa: *const u8, + pb: *const u8, + k: usize, + ab: &mut [[TI; NR]; MR], +) where + TI: LADatum, + f16: AsPrimitive<TI>, + i8: AsPrimitive<TI>, +{ + assert!(k % Q4_0.block_len() == 0); + let len = (k * MR) / Q4_0.block_len() * Q4_0.block_bytes(); + let mut pa = NibbleReader::for_slice(std::slice::from_raw_parts(pa, len)); + let b = pb as *const TI; + for bk in 0..k / 32 { + let mut scales: [TI; MR] = [TI::zero(); MR]; + scales.iter_mut().for_each(|x| *x =
pa.read_f16().as_()); + for ik in 0..32 { + let mut a: [TI; MR] = [TI::zero(); MR]; + a.iter_mut().zip(&scales).for_each(|(x, s)| *x = *s * (pa.read_i4() - 8).as_()); + let b = std::slice::from_raw_parts(b.add(NR * (ik + 32 * bk)), NR); + for i in 0..MR { + for j in 0..NR { + ab[i][j] += a[i] * b[j]; + } + } + } + } +} + unsafe fn store_t( tile: &OutputStoreKer, ab: &[[TI; NR]; MR], @@ -83,6 +114,8 @@ unsafe fn kernel(mut pnl: *const FusedKerS where TI: LADatum + ScaleShiftAndRound + AsPrimitive, usize: AsPrimitive, + i8: AsPrimitive<TI>, + f16: AsPrimitive<TI>, { unsafe { let mut ab = [[TI::zero(); NR]; MR]; @@ -171,8 +204,13 @@ where FusedKerSpec::AddMatMul { k, pa, pb, packing } => { use std::mem::transmute; if TI::datum_type().is_float() { - add_mat_mul::<TI, MR, NR>(pa, pb, k, &mut ab); + if packing == 0 { + add_mat_mul::<TI, MR, NR>(pa, pb, k, &mut ab); + } else if packing == 1 { + add_mat_mul_pq40(pa, pb, k, &mut ab); + } } else if TI::datum_type() == i32::datum_type() { + // transmute to allow using i32 explicitly in add_mat_mul's generic params let ab = transmute::<&mut [[TI; NR]; MR], &mut [[i32; NR]; MR]>(&mut ab); if packing == 0 { add_mat_mul::<i32, MR, NR>(pa, pb, k, ab) @@ -199,30 +237,54 @@ where 0 } +const PQ40_R4: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4); + MMMKernelWrapper!(f16, generic_f16_4x4; kernel::<f16, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true); -MMMKernelWrapper!(f16, generic_f16_4x1; kernel::<f16, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true); -MMMKernelWrapper!(f32, generic_f32_4x4; kernel::<f32, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true); -MMMKernelWrapper!(f32, generic_f32_4x1; kernel::<f32, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true); +MMMKernelWrapper!(f16, generic_f16_4x1; kernel::<f16, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true, + packing_defs: { + const F16_B: PackedFormat = PackedFormat::new(DatumType::F16, 1, 4, 0); + const PQ40_F16: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R4, &F16_B); + }, + packings: PQ40_F16, + test: mmm_packed_packed_tests!{ true, generic_f16_4x1, q40f16:1 } +); + +MMMKernelWrapper!(f32, generic_f32_4x4; kernel::<f32, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true, + packing_defs: { + const F32_B: PackedFormat = PackedFormat::new(DatumType::F32, 4, 4, 0); + const PQ40_F32: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R4, &F32_B); + }, + packings: PQ40_F32, + test: mmm_packed_packed_tests!{ true, generic_f32_4x4, q40f32:1 } +); +MMMKernelWrapper!(f32, generic_f32_4x1; kernel::<f32, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true, + packing_defs: { + const F32_B: PackedFormat = PackedFormat::new(DatumType::F32, 1, 4, 0); + const PQ40_F32: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R4, &F32_B); + }, + packings: PQ40_F32, + test: mmm_packed_packed_tests!{ true, generic_f32_4x1, q40f32:1 } +); MMMKernelWrapper!(f64, generic_f64_4x4; kernel::<f64, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true); MMMKernelWrapper!(f64, generic_f64_4x1; kernel::<f64, 4, 1>; 4, 1; 4, 4; 0, 0; no_prefetch, true); MMMKernelWrapper!(i32, generic_i32_4x4; kernel::<i32, 4, 4>; 4, 4; 4, 4; 0, 0; no_prefetch, true, packing_defs: { - const I8_A: Packer = Packer::new(DatumType::I8, 4, 4, 0); - const I8_B: Packer = Packer::new(DatumType::I8, 4, 4, 0); + const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 4, 4, 0); + const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 4, 4, 0); const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B); }, packings: I8_I8, - test: mmm_packed_packed_tests!{ true, generic_i32_4x4, i8i8:1, i8, i8, i32, i32 } + test: mmm_packed_packed_tests!{ true, generic_i32_4x4, i8i8:1 } ); MMMKernelWrapper!(i32,
generic_i32_4x1; kernel::; 4, 1; 4, 4; 0, 0; no_prefetch, true, packing_defs: { - const I8_A: Packer = Packer::new(DatumType::I8, 4, 4, 0); - const I8_B: Packer = Packer::new(DatumType::I8, 1, 4, 0); + const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 4, 4, 0); + const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 1, 4, 0); const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B); }, packings: I8_I8, - test: mmm_packed_packed_tests!{ true, generic_i32_4x1, i8i8:1, i8, i8, i32, i32 } + test: mmm_packed_packed_tests!{ true, generic_i32_4x1, i8i8:1 } ); #[cfg(test)] @@ -231,10 +293,10 @@ MMMKernelWrapper!(f32, generic_f32_3x2; kernel::; 3, 2; 4, 4; 0, 0; n #[cfg(test)] MMMKernelWrapper!(i32, generic_i32_3x2; kernel::; 3, 2; 4, 4; 0, 0; no_prefetch, true, packing_defs: { - const I8_A: Packer = Packer::new(DatumType::I8, 3, 4, 0); - const I8_B: Packer = Packer::new(DatumType::I8, 2, 4, 0); + const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 3, 4, 0); + const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 2, 4, 0); const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B); }, packings: I8_I8, - test: mmm_packed_packed_tests!{ true, generic_i32_3x2, i8i8:1, i8, i8, i32, i32 } + test: mmm_packed_packed_tests!{ true, generic_i32_3x2, i8i8:1 } ); diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index b2b775f7c1..ac9f08c5aa 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -1,6 +1,9 @@ +//use crate::frame::block_quant::{PackedBlockQuantFormat, Q4_0}; use crate::mmm::no_prefetch; use tract_data::prelude::f16; +// const PQ40_R32: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 32); + MMMExternKernel!(f16, fma_mmm_f16_8x8; 8, 8; 32, 2; 0, 0; no_prefetch, is_x86_feature_detected!("fma") && is_x86_feature_detected!("f16c")); MMMExternKernel!(f32, fma_mmm_f32_8x8; 8, 8; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma")); @@ -10,6 +13,19 @@ MMMExternKernel!(f32, fma_mmm_f32_24x4; 24, 4; 32, 4; 0, 0; no_prefetch, is_x86_ MMMExternKernel!(f32, fma_mmm_f32_32x3; 32, 3; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma")); MMMExternKernel!(f32, fma_mmm_f32_40x2; 40, 2; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma")); MMMExternKernel!(f32, fma_mmm_f32_64x1; 64, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma")); + +MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma")); +/* +MMMExternKernel!(f32, fma_mmm_f32_32x1; 32, 1; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("fma"), + packing_defs: { + const F32_B: PackedFormat = PackedFormat::new(DatumType::F32, 1, 4, 0); + const PQ40_F32: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&super::PQ40_R32, &F32_B); + }, + packings: PQ40_F32, + test: mmm_packed_packed_tests!{ is_x86_feature_detected!("fma"), fma_mmm_f32_32x1, q40f32:1 } +); +*/ + MMMExternKernel!(f32, avx512_mmm_f32_128x1; 128, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f")); MMMExternKernel!(f32, avx512_mmm_f32_16x1; 16, 1; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f")); MMMExternKernel!(f32, avx512_mmm_f32_16x12; 16, 12; 64, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx512f")); @@ -22,10 +38,10 @@ MMMExternKernel!(f32, avx512_mmm_f32_80x2; 80, 2; 64, 4; 0, 0; no_prefetch, is_x MMMExternKernel!(i32, avx2_mmm_i32_8x8; 8, 8; 32, 4; 0, 0; no_prefetch, is_x86_feature_detected!("avx2"), packing_defs: { - const I8_A: Packer = Packer::new(DatumType::I8, 
8, 32, 0); - const I8_B: Packer = Packer::new(DatumType::I8, 8, 4, 0); + const I8_A: PackedFormat = PackedFormat::new(DatumType::I8, 8, 32, 0); + const I8_B: PackedFormat = PackedFormat::new(DatumType::I8, 8, 4, 0); const I8_I8: (&dyn MMMInputFormat, &dyn MMMInputFormat) = (&I8_A, &I8_B); }, packings: I8_I8, - test: mmm_packed_packed_tests!{ true, avx2_mmm_i32_8x8, i8i8:1, i8, i8, i32, i32 } + test: mmm_packed_packed_tests!{ true, avx2_mmm_i32_8x8, i8i8:1 } ); diff --git a/linalg/tests/virtual_im2col.rs b/linalg/tests/virtual_im2col.rs index 5323bc3940..a0f9267727 100644 --- a/linalg/tests/virtual_im2col.rs +++ b/linalg/tests/virtual_im2col.rs @@ -7,8 +7,8 @@ use proptest::strategy::{BoxedStrategy, Strategy}; use tract_data::internal::*; use tract_linalg::frame::mmm::FusedSpec; // use tract_linalg::frame::mmm::{VirtualInput, VirtualInputSpec}; -use tract_linalg::frame::{Packer, PackingWriter}; -use tract_linalg::mmm::MMMInput; +use tract_linalg::frame::{PackedFormat, PackingWriter}; +use tract_linalg::mmm::MMMInputValue; use DatumType::F32; proptest::proptest! { @@ -154,16 +154,16 @@ impl ConvProblem { let (a_pack, b_pack) = mmm.packings()[0]; let a = a_pack.prepare_tensor(&reshaped_filters, 0, 1)?; unsafe { - let im2col: Box = if self.lazy_im2col { + let im2col: Box = if self.lazy_im2col { LazyIm2colSpec { full_kernel_shape: self.filters.shape().into(), - packer: b_pack.downcast_ref::().unwrap().clone(), + packer: b_pack.downcast_ref::().unwrap().clone(), } .wrap(&self.input.view()) } else { EagerIm2colSpec { full_kernel_shape: self.filters.shape().into(), - packer: b_pack.downcast_ref::().unwrap().clone(), + packer: b_pack.downcast_ref::().unwrap().clone(), } .wrap(&self.input.view()) }; @@ -216,12 +216,12 @@ fn tensor(shape: Vec) -> Tensor { #[derive(Clone, Debug, Hash)] struct EagerIm2colSpec { - packer: Packer, + packer: PackedFormat, full_kernel_shape: TVec, } impl EagerIm2colSpec { - fn wrap(&self, input: &TensorView) -> Box { + fn wrap(&self, input: &TensorView) -> Box { let (_, k, n, h, w) = mknhw(&self.full_kernel_shape, input.shape()); // let input = input.to_array_view::().unwrap(); let ci = input.shape()[0]; @@ -239,7 +239,7 @@ impl EagerIm2colSpec { #[derive(Clone, Debug, Hash)] struct EagerIm2col { - packer: Packer, + packer: PackedFormat, im2col: Tensor, k: usize, } @@ -250,7 +250,7 @@ impl Display for EagerIm2col { } } -impl MMMInput for EagerIm2col { +impl MMMInputValue for EagerIm2col { fn scratch_panel_buffer_layout(&self) -> Option { Some( Layout::from_size_align( @@ -261,7 +261,7 @@ impl MMMInput for EagerIm2col { ) } - fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> *const u8 { + fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> { let buffer = buffer.unwrap(); let mn = self.im2col.shape()[1]; unsafe { @@ -275,7 +275,7 @@ impl MMMInput for EagerIm2col { (i * self.packer.r)..((i + 1) * self.packer.r), ); } - buffer + Ok(buffer) } fn k(&self) -> usize { @@ -293,12 +293,12 @@ impl MMMInput for EagerIm2col { #[derive(Clone, Debug, Hash)] struct LazyIm2colSpec { - packer: Packer, + packer: PackedFormat, full_kernel_shape: TVec, } impl LazyIm2colSpec { - fn wrap(&self, input: &TensorView) -> Box { + fn wrap(&self, input: &TensorView) -> Box { let (_, _, _, h, w) = mknhw(&self.full_kernel_shape, input.shape()); let kh = self.full_kernel_shape[0]; let kw = self.full_kernel_shape[1]; @@ -331,7 +331,7 @@ impl LazyIm2colSpec { #[derive(Clone, Debug, Hash)] struct LazyIm2col { - packer: Packer, + packer: PackedFormat, image: *const 
f32, n_offsets: Vec, k_offsets: Vec, @@ -345,7 +345,7 @@ impl Display for LazyIm2col { } } -impl MMMInput for LazyIm2col { +impl MMMInputValue for LazyIm2col { fn scratch_panel_buffer_layout(&self) -> Option<Layout> { Some( Layout::from_size_align( @@ -356,7 +356,7 @@ impl MMMInput for LazyIm2col { - fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> *const u8 { + fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8> { let buffer = buffer.unwrap() as *mut f32; let mn_end = ((i + 1) * self.packer.r).min(self.n_offsets.len()); let n_range = (i * self.packer.r)..mn_end; @@ -373,7 +373,7 @@ } } } - buffer as _ + Ok(buffer as _) } fn k(&self) -> usize { diff --git a/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl b/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl new file mode 100644 index 0000000000..8cfac3b774 --- /dev/null +++ b/linalg/x86_64/fma/fma_mmm_f32_32x1.tmpl @@ -0,0 +1,155 @@ +{% comment %} +// vim: set syntax=asm : + +/* mmm 32 x 1 + + ymm0 + ymm1 + ymm2 + ymm3 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +{% endcomment %} + +{% include "preamble.tmpliq" type:"f32", size:"32x1", suffix:suffix, G:G %} + +{{L}}clear: + vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov rcx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rbx, [rdi + 8] // k + mov r8, [rdi + 32] // packing + test rbx, rbx + jz {{L}}non_linear_loop + + cmp r8, 1 + jz {{L}}q40f32 + +{{align}} 16 +{{L}}main_loop_packed_packed: + vbroadcastss ymm15, dword ptr [rcx] + + vmovaps ymm8, [rax] + vmovaps ymm9, [rax + 32] + vmovaps ymm10, [rax + 64] + vmovaps ymm11, [rax + 96] + + vfmadd231ps ymm0, ymm15, ymm8 + vfmadd231ps ymm1, ymm15, ymm9 + vfmadd231ps ymm2, ymm15, ymm10 + vfmadd231ps ymm3, ymm15, ymm11 + + add rcx, 4 + add rax, 128 + sub rbx, 1 + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}q40f32: + // read scales: 32 f16 scales (64 bytes) per k-block + + jmp {{L}}unsupported + +{% include "fma_mmm_f32_scalars.tmpliq" from:0, to:3, type:"f32" %} +{% include "fma_mmm_f32_per_rows.tmpliq" mr:32, from:0, to:3, type:"f32" %} +{% include "fma_mmm_f32_per_cols.tmpliq" mr:32, from:0, to:3, type:"f32" %} + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + + cmp rsi, 4 + jne {{L}}add_unicast_generic + + {% for row in (0..3) %} + vaddps ymm{{row}}, ymm{{row}}, [ r10 + {{row|times:32}} ] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast_generic: + mov eax, 0 +{% for i in (0..3) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} +{% for i in (0..3) %} + pinsrd xmm15, eax, {{i}} + add eax, esi +{% endfor %} + + vperm2f128 ymm14, ymm14, ymm15, 32 // ymm14 <- xmm14::xmm15 +{% for i in (0..3) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + + vaddps ymm{{i}}, ymm{{i}}, ymm12 + lea r10, [ r10 + rsi * 8 ] +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vbroadcastss ymm14, dword ptr [rbx] + +{% for i in (0..3) %} + vmovups ymm12, [rax + {{i|times:32}}] + vfmadd231ps ymm{{i}}, ymm12, ymm14 +{% endfor %} + jmp {{L}}non_linear_loop +
+{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + + cmp rsi, 4 + jne {{L}}store_generic + + {% for row in (0..3) %} + vmovups [r8 + {{row|times:32}}], ymm{{row}} + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_generic: + + {% for vec in (0..3) %} + {% for half in (0..1) %} + {% if half == 0 %} + movaps xmm9, xmm{{vec}} + {% else %} + vperm2f128 ymm9, ymm{{vec}}, ymm{{vec}}, 1 + {% endif %} + {% for row in (0..3) %} + vextractps dword ptr [r8], xmm9, {{row}} + add r8, rsi + {% endfor %} + {% endfor %} + {% endfor %} + + jmp {{L}}non_linear_loop + + +{% include "postamble.tmpliq" type:"f32", size:"32x1", suffix:suffix, G:G, L:L %}
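
Notes on the reworked packing API, with minimal sketches. PackedFormat (formerly Packer) is now itself an MMMInputFormat, and prepare_tensor yields a format-aware MMMInputValue. Only the type, trait and method names below come from the patch; the helper function, its dimensions and axis choices are illustrative.

use tract_data::internal::*;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::{MMMInputFormat, MMMInputValue};

// Hypothetical helper: pack an [mn, k] f32 tensor and walk its panels.
fn pack_and_walk(t: &Tensor) -> TractResult<()> {
    // r=4 rows per panel, 16-byte panel alignment, no end padding.
    let format = PackedFormat::new(DatumType::F32, 4, 16, 0);
    // k is axis 1, mn is axis 0 in this layout.
    let packed: Box<dyn MMMInputValue> = format.prepare_tensor(t, 1, 0)?;
    for panel in 0..packed.panels_count() {
        // Eager values return a pointer into their Blob and ignore the scratch
        // buffer; lazy ones pack into it on demand, hence the fallible signature.
        let _panel: *const u8 = packed.panel_bytes(panel, None)?;
    }
    Ok(())
}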
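At the MatMatMul level, the same pair of prepared inputs feeds the fused-spec pipeline, as the reworked run() in the tests does. A sketch, under the assumption that dyn MatMatMul exposes the same c_view/run surface the tests use on kernels:

use tract_data::internal::*;
use tract_linalg::frame::mmm::FusedSpec;
use tract_linalg::mmm::MatMatMul;

// Illustrative f32 matmul: a is [m, k], b is [k, n].
fn matmul(mmm: &dyn MatMatMul, a: &Tensor, b: &Tensor, m: usize, n: usize) -> TractResult<Tensor> {
    let (pack_a, pack_b) = mmm.packings()[0];
    let pa = pack_a.prepare_tensor(a, 1, 0)?;
    let pb = pack_b.prepare_tensor(b, 0, 1)?;
    let mut c = Tensor::zero::<f32>(&[m, n])?;
    unsafe {
        let store = mmm.c_view(0, 1).wrap(&c.view_mut());
        mmm.run(
            m,
            n,
            &[FusedSpec::AddMatMul { a: &*pa, b: &*pb, packing: 0 }, FusedSpec::Store(store)],
        )?;
    }
    Ok(c)
}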
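The packing destination itself is now a raw Blob rather than a Tensor, allocated with the format's alignment. Condensed from pack_tensor above into a free function for illustration:

use tract_data::internal::*;

// Illustrative: allocate a packing destination the way pack_tensor now does.
fn alloc_packed(dt: DatumType, packed_len: usize, alignment: usize) -> Blob {
    let mut packed = Blob::new_for_size_and_align(dt.size_of() * packed_len, alignment);
    // The blob starts uninitialized; zeroing it in debug builds makes the
    // end-of-panel padding bytes deterministic.
    if cfg!(debug_assertions) {
        packed.as_bytes_mut().fill(0u8);
    }
    packed
}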
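Finally, the layout the generic q40 path consumes: for each 32-deep block of k, the packed A panel stores one f16 scale per row, then 32 row-interleaved groups of 4-bit values with an implicit zero point of 8 (for Q4_0 that works out to 18 bytes per row per block: 2 scale bytes plus 16 nibble bytes). A dequantization sketch mirroring the read order of add_mat_mul_pq40; NibbleReader, read_f16 and read_i4 are used as in the patch, while the helper and its const generic are illustrative:

use tract_linalg::frame::block_quant::NibbleReader;

// Illustrative: dequantize one 32-deep Q4_0 block of an MR-row packed panel.
fn dequant_q40_block<const MR: usize>(block: &[u8]) -> [[f32; 32]; MR] {
    let mut reader = NibbleReader::for_slice(block);
    // One f16 scale per panel row comes first...
    let mut scales = [0f32; MR];
    for scale in &mut scales {
        *scale = reader.read_f16().to_f32();
    }
    // ...then, for each of the 32 k positions, one nibble per row (zero point 8).
    let mut out = [[0f32; 32]; MR];
    for ik in 0..32 {
        for row in 0..MR {
            out[row][ik] = scales[row] * (reader.read_i4() - 8) as f32;
        }
    }
    out
}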