Skip to content

Commit

Permalink
fix llms
Browse files: browse the repository at this point in the history
  • Loading branch information
kali committed Sep 20, 2024
1 parent 79239f8 commit 6646609
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions core/src/ops/matmul/pack.rs
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
use crate::axes::Axis;
use crate::internal::*;
use crate::ops::matmul::de_block_quant::BlockQuantValue;
use ndarray::*;
use tract_data::TooEarly;
use tract_itertools::Itertools;
Expand Down Expand Up @@ -91,7 +92,7 @@ impl TypedOp for MatMatMulPack {
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.packers.iter().all(|p| *p == Packer::Identity) {
return TypedModelPatch::shunt_one_op(model, node)
return TypedModelPatch::shunt_one_op(model, node);
}
Ok(None)
}
Expand All @@ -110,7 +111,21 @@ impl MatMatMulPack {
bail!(TooEarly::Other("Undetermined scenario".into()))
};
if *packer == Packer::Identity {
return Ok(tvec!(input))
return Ok(tvec!(input));
}
if let Packer::PackBlockQuant(pbqf) = packer {
let value = input
.to_scalar::<Opaque>()?
.downcast_ref::<BlockQuantValue>()
.context("Expected a BlockQuant value")?;
ensure!(self.k_axis == 1);
ensure!(self.mn_axis == 0);
ensure!(pbqf.bq.same_as(&*value.fact.format));
let k = value.fact.shape[0].to_usize()?;
let packed = pbqf.pack(&value.value, k)?;
let mmm_input: Box<dyn MMMInputValue> = Box::new(packed);
let t = tensor0(Opaque::from(mmm_input));
return Ok(tvec!(t.into_tvalue()))
}
let output_shape: TVec<usize> = self.output_shape(input.shape());
let stores = if output_shape.iter().all(|d| *d == 1) {
Expand Down

0 comments on commit 6646609

Please sign in to comment.