Skip to content

Commit

Permalink
wip fixing zipping
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 22, 2024
1 parent d55a0d9 commit 28cde77
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 95 deletions.
6 changes: 1 addition & 5 deletions linalg/src/frame/block_quant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ use std::ops::Deref;

mod helpers;
mod q4_0;
mod repack;

pub use helpers::{NibbleReader, NibbleWriter};
pub use q4_0::Q4_0;
pub use repack::RepackingPackedBlockQuantValue;

use crate::mmm::{EagerPackedInput, MMMInputFormat};

Expand Down Expand Up @@ -94,15 +92,13 @@ pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downc

fn pack(&self, input: &[u8], k: usize, r: usize, zip: usize) -> TractResult<EagerPackedInput>;

/*
unsafe fn repack_panel(
unsafe fn extract_panel(
&self,
value: &EagerPackedInput,
target: &PackedFormat,
panel: usize,
scratch: *mut u8,
) -> TractResult<()>;
*/
}

dyn_clone::clone_trait_object!(BlockQuant);
Expand Down
95 changes: 51 additions & 44 deletions linalg/src/frame/block_quant/q4_0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ impl<const QK: usize> BaseQ4_0<QK> {
}
}

/*
unsafe fn repack_panel_t<T: Float + 'static>(
unsafe fn extract_panel_t<T: Float + Debug + 'static>(
&self,
value: &EagerPackedInput,
target: &PackedFormat,
Expand All @@ -60,27 +59,56 @@ impl<const QK: usize> BaseQ4_0<QK> {
f16: AsPrimitive<T>,
i8: AsPrimitive<T>,
{
ensure!(value.format.r() == target.r);
let pbqf: &PackedBlockQuantFormat = value.format.downcast_ref().with_context(|| {
format!("Expecing PackedBlockQuantFormat, found {:?}", value.format)
})?;
ensure!(pbqf.r == target.r);
ensure!(value.k % self.block_len() == 0);
ensure!(pbqf.bq.same_as(self));
let scratch = std::slice::from_raw_parts_mut(scratch as *mut T, value.k * target.r);
let blocks_for_k = value.k / self.block_len();
let row_bytes = blocks_for_k * self.block_bytes();
dbg!(&value);
dbg!(&value.packed);
dbg!(&value.packed[panel * target.r * row_bytes..]);
let mut input = NibbleReader::for_slice(&value.packed[panel * target.r * row_bytes..]);
let mut scales = vec![T::zero(); target.r];
let mut scratch = scratch.iter_mut();
let zipped_order = zipped_order(pbqf.r, pbqf.zip);
let mut weights = vec!(0i8; pbqf.r);
dbg!(pbqf);
dbg!(panel, blocks_for_k);
for _ in 0..blocks_for_k {
for s in &mut scales {
*s = input.read_f16().as_();
}
dbg!(&scales);
for _ in 0..self.block_len() {
for &s in &scales {
*scratch.next().unwrap() = s * (input.read_i4() - 8).as_();
for &o in &zipped_order {
weights[o] = input.read_i4();
}
for (w, s) in weights.iter().zip(scales.iter()) {
*scratch.next().unwrap() = *s * (*w - 8).as_();
}
}
}
Ok(())
}
*/
}

/// Computes the row ordering used for "zipped" panels of `r` rows.
///
/// The result has length `r`. With `zip == 0` it is the identity ordering
/// `0, 1, .., r - 1`. Otherwise rows are handled in groups of `2 * zip`:
/// inside each group, even positions map to the first `zip` rows of the group
/// and odd positions to the last `zip` rows, so the two halves of the group
/// end up interleaved (e.g. `r = 8, zip = 4` gives `[0, 4, 1, 5, 2, 6, 3, 7]`).
///
/// NOTE(review): the result is only guaranteed to be a permutation of `0..r`
/// when `r` is a multiple of `2 * zip`; otherwise the formula can yield
/// indices `>= r` (e.g. `r = 4, zip = 4` gives `[0, 4, 1, 5]`). Callers are
/// presumably expected to pass compatible `r` / `zip` values.
fn zipped_order(r: usize, zip: usize) -> Vec<usize> {
    if zip == 0 {
        // No zipping requested: identity. This early return also protects
        // the `% group` / `/ group` below from a zero divisor.
        return (0..r).collect();
    }
    let group = 2 * zip;
    (0..r)
        .map(|i| {
            // First row of the group of `2 * zip` rows that `i` falls in.
            let group_start = (i / group) * group;
            // Position inside one half of the group.
            let lane = (i % group) / 2;
            // 0 selects the first half of the group, 1 the second half.
            let side = i % 2;
            group_start + side * zip + lane
        })
        .collect()
}

impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
Expand Down Expand Up @@ -135,18 +163,7 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
let mut blob =
unsafe { Blob::for_layout(Layout::from_size_align(panel_bytes * panels, 128)?) };
let mut writer = NibbleWriter::for_slice(&mut blob);
let order = if zip == 0 {
(0..r).collect_vec()
} else {
(0..r)
.map(|i| {
let vec_pair_ix = i / (2 * zip);
let lane = (i % (2 * zip)) / 2;
let side = i % 2;
vec_pair_ix * 2 * zip + side * zip + lane
})
.collect_vec()
};
let order = zipped_order(r, zip);
for p in 0..panels {
let input = &input[(r * p) * row_bytes..];
let mut readers = (0..r)
Expand Down Expand Up @@ -182,17 +199,15 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
})
}

/*
unsafe fn repack_panel(
unsafe fn extract_panel(
&self,
value: &EagerPackedInput,
target: &PackedFormat,
panel: usize,
scratch: *mut u8,
) -> TractResult<()> {
dispatch_floatlike!(Self::repack_panel_t(target.dt)(self, value, target, panel, scratch))
dispatch_floatlike!(Self::extract_panel_t(target.dt)(self, value, target, panel, scratch))
}
*/
}

impl<const QK: usize> Display for BaseQ4_0<QK> {
Expand Down Expand Up @@ -264,48 +279,40 @@ mod tests {
cycle_f16(Q4_0, &[-1234.0]);
}



#[test]
fn packing() -> TractResult<()> {
let (q, k, m, r) = (BaseQ4_0::<2>, 4, 4, 2);
test_packing(BaseQ4_0::<2>, 4, 4, 2, 0)
}

/// Runs `test_packing` on `BaseQ4_0::<2>` with `k = 2`, `m = 8`, `r = 8`
/// and a zip factor of 4 (the non-zipped case is covered by `packing`).
#[test]
fn packing_with_zip() -> TractResult<()> {
test_packing(BaseQ4_0::<2>, 2, 8, 8, 4)
}

fn test_packing(q: impl BlockQuant, k: usize, m: usize, r:usize, zip: usize) -> TractResult<()> {
let weights_orig =
Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
.into_tensor();
let weights_f32 =
q.dequant_f32(&q.quant_f32(weights_orig.as_slice::<f32>()?)?)?.into_shape(&[m, k])?;
eprintln!("{:?}", weights_f32.to_array_view::<f32>()?);
let packer = PackedFormat::new(f32::datum_type(), r, 128);
let packed_f32 = packer.pack_tensor(&weights_f32, 1, 0)?;
assert_eq!(packed_f32.panels_count(), 2);

let q4 = q.quant_f32(&weights_f32.as_slice::<f32>()?)?;
let packed_q4 = q.pack(&q4, k, r, 0)?;
let packed_q4 = q.pack(&q4, k, r, zip)?;

for panel in 0..2 {
for panel in 0..packed_f32.panels_count() {
unsafe {
let panel_f32 = packed_f32.panel_bytes(panel, None)?;
let panel_f32 = std::slice::from_raw_parts(panel_f32 as *const f32, k * r);
eprintln!("{panel_f32:?}");
let mut panel_q4 = Tensor::zero::<f32>(&[k * r])?;
// q.repack_panel(&packed_q4, &packer, panel, panel_q4.as_bytes_mut().as_mut_ptr())?;
eprintln!("{panel_q4:?}");
q.extract_panel(&packed_q4, &packer, panel, panel_q4.as_bytes_mut().as_mut_ptr())?;
assert_eq!(panel_q4.as_slice::<f32>()?, panel_f32);
}
}
Ok(())
}

#[test]
fn packing_with_zip() -> TractResult<()> {
let (q, k, m, r) = (BaseQ4_0::<2>, 2, 8, 8);
let weights_orig = Array2::from_shape_fn((m, k), |(m, _)| m as f32).into_tensor();
let weights_f32 =
q.dequant_f32(&q.quant_f32(weights_orig.as_slice::<f32>()?)?)?.into_shape(&[m, k])?;

let q4 = q.quant_f32(&weights_f32.as_slice::<f32>()?)?;
let packed_q4 = q.pack(&q4, k, r, 4)?;
unsafe {

}
Ok(())
}
}
46 changes: 0 additions & 46 deletions linalg/src/frame/block_quant/repack.rs

This file was deleted.

0 comments on commit 28cde77

Please sign in to comment.