Skip to content

Commit

Permalink
Proptest Silu and RMS norm and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Jul 19, 2024
1 parent d488ab7 commit 87e13d2
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 190 deletions.
4 changes: 2 additions & 2 deletions metal/src/kernels/nn/nn_ops.metal
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct Min {
return simd_min(val);
}

static constexpr constant U init = metal::numeric_limits<U>::max();
static constexpr constant U init = metal::numeric_limits<U>::infinity();

// Operator
U operator()(U a, U b) {
Expand All @@ -60,7 +60,7 @@ struct Max {
return simd_max(val);
}

static constexpr constant U init = metal::numeric_limits<U>::lowest();
static constexpr constant U init = -metal::numeric_limits<U>::infinity();

// Operator
U operator()(U a, U b) {
Expand Down
355 changes: 168 additions & 187 deletions metal/src/kernels/nn/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,191 +99,172 @@ impl RmsNorm {

#[cfg(test)]
mod tests {
// use super::*;
// use crate::IntoMetal;
// use derive_new::new;
// use num_traits::AsPrimitive;
// use num_traits::Float;
// use proptest::collection::vec;
// use proptest::prelude::*;
// use tract_core::internal::Tensor;
// use tract_core::ops::nn::RmsNorm as TractRmsNorm;
// use tract_core::ops::nn::RmsNormExp;

// #[test]
// fn test_softmax_f32() -> Result<()> {
// objc::rc::autoreleasepool(|| {
// crate::METAL_CONTEXT.with_borrow(|context| {
// let m = 4;
// let k = 4;
// let axis = 1;

// let a =
// Tensor::from_shape(&[m, k], &(0..m * k).map(|f| f as f32).collect::<Vec<_>>())?
// .into_metal()?;

// let cpu_softmax = TractRmsNorm {
// axes: tvec![axis],
// quant_output_dt: None,
// exp: RmsNormExp::Libc,
// };

// let cpu_output =
// cpu_softmax.eval(tvec![a.to_cpu().into_tvalue()])?[0].clone().into_tensor();
// let metal_output = RmsNorm.eval(context, &a, axis)?;
// cpu_output.close_enough(&metal_output.to_cpu(), Approximation::Approximate)?;
// Ok(())
// })
// })
// }

// #[test]
// fn test_softmax_f32_2() -> Result<()> {
// objc::rc::autoreleasepool(|| {
// crate::METAL_CONTEXT.with_borrow(|context| {
// let shape = [8, 4, 3];
// let num_elements = shape.iter().product();
// let axis = 0;

// let a = Tensor::from_shape(
// &shape,
// &(0..num_elements).map(|f| f as f32 / 1000.0).collect::<Vec<_>>(),
// )?
// .into_metal()?;

// let cpu_softmax = TractRmsNorm {
// axes: tvec![axis],
// quant_output_dt: None,
// exp: RmsNormExp::Libc,
// };

// let cpu_output =
// cpu_softmax.eval(tvec![a.to_cpu().into_tvalue()])?[0].clone().into_tensor();
// let metal_output = RmsNorm.eval(context, &a, axis)?;
// cpu_output.close_enough(&metal_output.to_cpu(), Approximation::Approximate)?;
// Ok(())
// })
// })
// }

// #[test]
// fn test_softmax_f16() -> Result<()> {
// objc::rc::autoreleasepool(|| {
// crate::METAL_CONTEXT.with_borrow(|context| {
// let m = 4;
// let k = 4;
// let axis = 1;

// let a = Tensor::from_shape(
// &[m, k],
// &(0..m * k).map(|f| -> f16 { f.as_() }).collect::<Vec<_>>(),
// )?
// .into_metal()?;

// let cpu_softmax = TractRmsNorm {
// axes: tvec![axis],
// quant_output_dt: None,
// exp: RmsNormExp::Libc,
// };

// let cpu_output =
// cpu_softmax.eval(tvec![a.to_cpu().into_tvalue()])?[0].clone().into_tensor();
// let metal_output = RmsNorm.eval(context, &a, axis)?;
// cpu_output.close_enough(&metal_output.to_cpu(), Approximation::Approximate)?;
// Ok(())
// })
// })
// }

// proptest::proptest! {
// #[test]
// fn softmax_prop_f32(pb in any::<RmsNormProblem<f32>>()) {
// fn run(pb: RmsNormProblem<f32>) -> TractResult<()> {
// let out = pb.run()?;
// let reference = pb.reference()?;

// out.close_enough(&reference, Approximation::Approximate)
// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
// }
// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
// }

// #[test]
// fn softmax_prop_f16(pb in any::<RmsNormProblem<f16>>()) {
// fn run(pb: RmsNormProblem<f16>) -> TractResult<()> {
// let out = pb.run()?;
// let reference = pb.reference()?;

// out.close_enough(&reference, Approximation::Approximate)
// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
// }

// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
// }
// }

// #[derive(Debug, new)]
// pub struct RmsNormProblem<F: Datum + Float>
// where
// F: Datum + Float,
// usize: AsPrimitive<F>,
// {
// pub shape: Vec<usize>,
// pub axis: usize,
// pub input: Vec<F>,
// }

// impl<F> Arbitrary for RmsNormProblem<F>
// where
// F: Datum + Float,
// usize: AsPrimitive<F>,
// {
// type Parameters = ();
// type Strategy = BoxedStrategy<Self>;

// fn arbitrary_with(_: ()) -> Self::Strategy {
// (0usize..3, 0usize..3)
// .prop_flat_map(|(left, right)| {
// let axis = left;
// let shape_len = usize::min(left + right + 1, 4);
// let shape = 1usize..10;
// (vec(shape, shape_len..=shape_len), Just(axis))
// })
// .prop_map(|(shape, axis)| {
// let input = (0..shape.iter().product::<usize>())
// .map(|f| f.as_() / 1000.as_())
// .collect::<Vec<_>>();
// Self { shape, axis, input }
// })
// .boxed()
// }
// }

// impl<F> RmsNormProblem<F>
// where
// F: Datum + Float + std::ops::AddAssign,
// usize: AsPrimitive<F>,
// {
// pub fn reference(&self) -> Result<Tensor> {
// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?;

// let cpu_softmax = TractRmsNorm {
// axes: tvec![self.axis],
// quant_output_dt: None,
// exp: RmsNormExp::Libc,
// };
// let cpu_output = cpu_softmax.eval(tvec![a.into_tvalue()])?[0].clone().into_tensor();
// Ok(cpu_output)
// }

// pub fn run(&self) -> Result<Tensor> {
// objc::rc::autoreleasepool(|| {
// crate::METAL_CONTEXT.with_borrow(|context| {
// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?;
// let metal_output = RmsNorm.eval(context, &a, self.axis)?;
// Ok(metal_output.to_cpu())
// })
// })
// }
// }
use super::*;
use crate::IntoMetal;
use derive_new::new;
use num_traits::AsPrimitive;
use num_traits::Float;
use proptest::collection::vec;
use proptest::prelude::*;
use tract_core::internal::Tensor;
use crate::rewrite_rules::BasicRmsNorm;


/// Compares the Metal `RmsNorm` kernel with the CPU `BasicRmsNorm` on a generated input.
///
/// The input tensor has the given `shape`; the element at flat index `i` is
/// `i * scale + offset`, computed in `f32` and then converted to `F`.
/// Both implementations normalize along `axis` with an epsilon of `0.0001`.
///
/// Returns an error, carrying a dump of the input and of both outputs, when the two
/// results are not approximately equal.
fn test_case<F>(shape: &[usize], axis: usize, offset: f32, scale: f32) -> Result<()>
where
    F: Float + Datum,
    usize: AsPrimitive<f32>,
    f32: AsPrimitive<F>,
{
    objc::rc::autoreleasepool(|| {
        crate::METAL_CONTEXT.with_borrow(|context| {
            let element_count: usize = shape.iter().product();

            // Build the ramp in f32 first so that f16 and f32 runs start from the same values.
            let values = (0..element_count)
                .map(|idx| -> F {
                    let position: f32 = idx.as_();
                    (position * scale + offset).as_()
                })
                .collect::<Vec<_>>();
            let a = Tensor::from_shape(shape, &values)?.into_metal()?;

            // The same epsilon tensor is handed to the CPU op and to the Metal kernel.
            let eps_value: F = 0.0001f32.as_();
            let eps = Arc::new(tensor0(eps_value));
            let cpu_rms = BasicRmsNorm { axis, eps: Arc::clone(&eps) };

            let cpu_outputs = cpu_rms.eval(tvec![a.to_cpu().into_tvalue()])?;
            let cpu_output = cpu_outputs[0].clone().into_tensor();
            let metal_output = RmsNorm.eval(context, &a, axis, &eps)?;

            cpu_output
                .close_enough(&metal_output.to_cpu(), Approximation::Approximate)
                .with_context(|| {
                    anyhow!(
                        "Input: {:?}, scale: {:?} Cpu: {:?}, Metal: {:?}",
                        a.to_cpu().dump(true),
                        scale,
                        cpu_output.dump(true),
                        metal_output.to_cpu().dump(true)
                    )
                })?;
            Ok(())
        })
    })
}


/// Smoke test: runs the same 4x4 problem (axis 1, offset -8, scale 1/100) in `f32` and `f16`.
#[test]
fn test_rms() -> Result<()> {
    let shape: [usize; 2] = [4, 4];
    test_case::<f32>(&shape, 1, -8.0, 1.0 / 100.0)?;
    test_case::<f16>(&shape, 1, -8.0, 1.0 / 100.0)?;
    Ok(())
}

proptest::proptest! {
#[test]
fn rms_prop_f32(pb in any::<RmsNormProblem<f32>>()) {
fn run(pb: RmsNormProblem<f32>) -> TractResult<()> {
let out = pb.run()?;
let reference = pb.reference()?;

out.close_enough(&reference, Approximation::Approximate)
.with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
}
run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
}

#[test]
fn rms_prop_f16(pb in any::<RmsNormProblem<f16>>()) {
fn run(pb: RmsNormProblem<f16>) -> TractResult<()> {
let out = pb.run()?;
let reference = pb.reference()?;

out.close_enough(&reference, Approximation::Approximate)
.with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
}

run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
}
}

/// One RMS-norm test problem: input data, the axis to normalize along, and the epsilon.
#[derive(Debug, new)]
pub struct RmsNormProblem<F>
where
    F: Datum + Float,
    usize: AsPrimitive<F>,
    f32: AsPrimitive<F>,
{
    /// Shape used to build the input tensor.
    pub shape: Vec<usize>,
    /// Axis handed to both the CPU reference op and the Metal kernel.
    pub axis: usize,
    /// Flat element data, given the layout `shape` by `Tensor::from_shape`.
    pub input: Vec<F>,
    /// Epsilon tensor handed to both the CPU reference op and the Metal kernel.
    pub eps: Arc<Tensor>,
}

impl<F> Arbitrary for RmsNormProblem<F>
where
    F: Datum + Float,
    usize: AsPrimitive<F>,
    f32: AsPrimitive<F>,
{
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Generates problems of rank 1 to 4 whose dimensions are each drawn from `1..10`.
    ///
    /// `axis` is drawn from `0..3` and is always smaller than the rank. Only the shape and
    /// the axis are random: element `i` (flat index) of the input is `i / 1000`, and `eps`
    /// is fixed at `0.0001`.
    fn arbitrary_with(_: ()) -> Self::Strategy {
        // First draw: the axis, plus how many dimensions follow it.
        let axis_and_trailing = (0usize..3, 0usize..3);

        // Second draw: a shape of the resulting rank (capped at 4), keeping the axis.
        let shape_and_axis = axis_and_trailing.prop_flat_map(|(axis, trailing)| {
            let rank = usize::min(axis + trailing + 1, 4);
            let dim = 1usize..10;
            (vec(dim, rank..=rank), Just(axis))
        });

        shape_and_axis
            .prop_map(|(shape, axis)| {
                let len = shape.iter().product::<usize>();
                let input = (0..len).map(|i| i.as_() / 1000.as_()).collect::<Vec<_>>();
                let eps_value: F = 0.0001f32.as_();
                Self { shape, axis, input, eps: Arc::new(tensor0(eps_value)) }
            })
            .boxed()
    }
}

impl<F> RmsNormProblem<F>
where
    F: Datum + Float + std::ops::AddAssign,
    usize: AsPrimitive<F>,
    f32: AsPrimitive<F>,
{
    /// Computes the expected output by evaluating `BasicRmsNorm` on the input tensor.
    pub fn reference(&self) -> Result<Tensor> {
        let input = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
        let op = BasicRmsNorm { axis: self.axis, eps: Arc::clone(&self.eps) };

        // Only the first output of the op is used.
        let outputs = op.eval(tvec![input.into_tvalue()])?;
        let first = outputs[0].clone().into_tensor();

        Ok(first)
    }

    /// Evaluates the Metal `RmsNorm` kernel on the input tensor and returns its output
    /// after `to_cpu()`.
    pub fn run(&self) -> Result<Tensor> {
        objc::rc::autoreleasepool(|| {
            crate::METAL_CONTEXT.with_borrow(|context| {
                let input = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
                let metal_input = input.into_metal()?;
                let metal_output = RmsNorm.eval(context, &metal_input, self.axis, &self.eps)?;
                Ok(metal_output.to_cpu())
            })
        })
    }
}
}
Loading

0 comments on commit 87e13d2

Please sign in to comment.