diff --git a/linalg/activations/benches/vm.rs b/linalg/activations/benches/vm.rs index 14c226623b..e8527def60 100644 --- a/linalg/activations/benches/vm.rs +++ b/linalg/activations/benches/vm.rs @@ -1,8 +1,8 @@ -use activations::{definitions, reference}; -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, BatchSize}; +use activations::{definitions, reference, Program}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -pub fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("relu"); +fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program) { + let mut group = c.benchmark_group(name); for size in [1i32, 32, 256, 1024, 8192].iter() { group.throughput(criterion::Throughput::Elements(*size as u64)); group.bench_with_input(BenchmarkId::new("Reference", size), size, |b, size| { @@ -10,26 +10,32 @@ pub fn criterion_benchmark(c: &mut Criterion) { || vec![1.0f32; *size as usize], |v| { for x in v { - reference::relu(black_box(x)); + r(black_box(x)); } }, - BatchSize::LargeInput - ) + BatchSize::LargeInput, + ) }); - let d = definitions::relu(); group.bench_with_input(BenchmarkId::new("VM", size), size, |b, size| { b.iter_batched( || vec![1.0f32; *size as usize], |v| { for x in v { - d.compute(black_box(x)); + prog.compute(black_box(x)); } }, - BatchSize::LargeInput - ) + BatchSize::LargeInput, + ) }); } } +fn criterion_benchmark(c: &mut Criterion) { + crit(c, "relu", reference::relu, &definitions::relu()); + crit(c, "hardswish", reference::hardswish, &definitions::hardswish()); + crit(c, "exp2f", reference::exp2f, &definitions::exp2f()); + crit(c, "sigmoid", reference::sigmoid, &definitions::sigmoid()); +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches);