Skip to content

Commit

Permalink
Merge pull request #74 from charles-r-earp/candle-benches
Browse files Browse the repository at this point in the history
candle benches
  • Loading branch information
charles-r-earp committed Aug 18, 2024
2 parents 03395dc + fccf325 commit 1f5eed2
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
cargo +nightly rustdoc --all-features -- --D warnings --cfg doc_cfg -A unexpected_cfgs
- name: msrv
run: |
cargo +nightly generate-lockfile -Zmsrv-policy --config "resolver.something-like-precedence='something-like-rust-version'"
cargo +nightly generate-lockfile -Zmsrv-policy --config "resolver.incompatible-rust-versions='fallback'"
cat Cargo.lock
cargo +1.70.0 check -p autograph -p neural-network-mnist-example --all-features --all-targets -v
cargo +1.70.0 check -p neural-network-benches --all-targets -v
Expand Down
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,21 @@ _NVIDIA GeForce GTX 1060 with Max-Q Design_

## LeNet5(training, batch_size = 100)

| | `autograph` | `tch` |
| :---------------- | :------------------------- | :------------------------------- |
| **`bf16_host`** | `482.80 ms` (✅ **1.00x**) | `75.30 ms` (🚀 **6.41x faster**) |
| **`f32_host`** | `5.44 ms` (✅ **1.00x**) | `3.09 ms` ( **1.76x faster**) |
| **`bf16_device`** | `1.76 ms` (✅ **1.00x**) | `17.99 ms` (❌ _10.20x slower_) |
| **`f32_device`** | `1.75 ms` (✅ **1.00x**) | `1.20 ms` (✅ **1.45x faster**) |
| | `autograph` | `tch` | `candle` |
|:------------------|:--------------------------|:---------------------------------|:-------------------------------- |
| **`bf16_host`** | `498.54 ms` (✅ **1.00x**) | `75.26 ms` (🚀 **6.62x faster**) | `N/A` |
| **`f32_host`** | `8.25 ms` (✅ **1.00x**) | `3.14 ms` (🚀 **2.63x faster**) | `34.17 ms` (❌ *4.14x slower*) |
| **`bf16_device`** | `1.76 ms` (✅ **1.00x**) | `17.63 ms` (❌ *10.02x slower*) | `N/A` |
| **`f32_device`** | `1.73 ms` (✅ **1.00x**) | `1.19 ms` (✅ **1.45x faster**) | `9.76 ms` (❌ *5.64x slower*) |

## LeNet5(inference, batch_size = 1,000)

| | `autograph` | `tch` |
| :---------------- | :------------------------ | :-------------------------------- |
| **`bf16_host`** | `1.78 s` (✅ **1.00x**) | `192.75 ms` (🚀 **9.25x faster**) |
| **`f32_host`** | `12.23 ms` (✅ **1.00x**) | `9.57 ms` (✅ **1.28x faster**) |
| **`bf16_device`** | `4.62 ms` (✅ **1.00x**) | `48.72 ms` (❌ _10.54x slower_) |
| **`f32_device`** | `4.76 ms` (✅ **1.00x**) | `1.84 ms` (🚀 **2.58x faster**) |
| | `autograph` | `tch` | `candle` |
|:------------------|:-------------------------|:---------------------------------|:-------------------------------- |
| **`bf16_host`** | `1.81 s` (✅ **1.00x**) | `193.60 ms` (🚀 **9.37x faster**) | `N/A` |
| **`f32_host`** | `15.56 ms` (✅ **1.00x**) | `9.46 ms` (✅ **1.64x faster**) | `94.23 ms` (❌ *6.06x slower*) |
| **`bf16_device`** | `4.65 ms` (✅ **1.00x**) | `48.63 ms` (❌ *10.46x slower*) | `N/A` |
| **`f32_device`** | `4.65 ms` (✅ **1.00x**) | `1.84 ms` (🚀 **2.52x faster**) | `10.81 ms` (❌ *2.33x slower*) |

See the [Neural Network](benches/neural-network-benches) benchmark.

Expand Down
5 changes: 4 additions & 1 deletion benches/neural-network-benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ tch = { version = "0.12.0", optional = true }
criterion = { version = "0.4.0", default-features = false }
anyhow = { workspace = true }
bytemuck = { workspace = true, optional = true }
candle-nn = { version = "0.6.0", optional = true }
candle-core = { version = "0.6.0", optional = true }

[dev-dependencies]
num-format.workspace = true

[features]
default = ["device"]
device = ["autograph/device"]
cuda = []
cuda = ["candle-nn?/cuda"]
tch = ["dep:tch", "dep:bytemuck"]
candle = ["dep:candle-nn", "dep:candle-core"]

[[bench]]
name = "benchmarks"
Expand Down
141 changes: 103 additions & 38 deletions benches/neural-network-benches/benches/benchmarks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,51 @@
use autograph::krnl::{device::Device, scalar::ScalarType};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use neural_network_benches::autograph_backend;
#[cfg(feature = "candle")]
use neural_network_benches::candle_backend;
#[cfg(feature = "tch")]
use neural_network_benches::tch_backend;
use num_format::{Locale, ToFormattedString};
use std::str::FromStr;

/// Lists the autograph devices to benchmark: the host, followed by the device
/// built for `index` when the `device` feature is enabled.
///
/// # Panics
/// Panics if building the device for `index` fails.
fn autograph_devices(
    #[cfg_attr(not(feature = "device"), allow(unused))] index: usize,
) -> impl IntoIterator<Item = Device> {
    let host = Device::host();
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "device")]
    let devices = [host, Device::builder().index(index).build().unwrap()];
    #[cfg(not(feature = "device"))]
    let devices = [host];
    devices
}

/// Lists the tch devices to benchmark: the CPU, followed by the CUDA device
/// `index` when the `cuda` feature is enabled.
#[cfg(feature = "tch")]
fn tch_devices(
    #[cfg_attr(not(feature = "cuda"), allow(unused))] index: usize,
) -> impl IntoIterator<Item = tch::Device> {
    let cpu = tch::Device::Cpu;
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "cuda")]
    let devices = [cpu, tch::Device::Cuda(index)];
    #[cfg(not(feature = "cuda"))]
    let devices = [cpu];
    devices
}

/// Lists the candle devices to benchmark: the CPU, followed by the CUDA device
/// `index` when the `cuda` feature is enabled.
///
/// # Panics
/// Panics if the CUDA device for `index` cannot be created.
#[cfg(feature = "candle")]
fn candle_devices(
    #[cfg_attr(not(feature = "cuda"), allow(unused))] index: usize,
) -> impl IntoIterator<Item = candle_core::Device> {
    use candle_core::Device;

    let cpu = Device::Cpu;
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "cuda")]
    let devices = {
        // `BackendDevice` provides `CudaDevice::new`.
        use candle_core::{backend::BackendDevice, CudaDevice};
        [cpu, Device::Cuda(CudaDevice::new(index).unwrap())]
    };
    #[cfg(not(feature = "cuda"))]
    let devices = [cpu];
    devices
}

pub fn criterion_benchmark(c: &mut Criterion) {
let device_index = if cfg!(feature = "device") {
let krnl_device = std::env::var("KRNL_DEVICE");
Expand All @@ -20,8 +60,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
} else {
0
};

#[cfg_attr(not(feature = "cuda"), allow(unused))]
#[allow(unused)]
let cuda_device_index = if cfg!(feature = "cuda") {
let cuda_device = std::env::var("CUDA_DEVICE");
println!("CUDA_DEVICE = {cuda_device:?}");
Expand All @@ -35,7 +74,6 @@ pub fn criterion_benchmark(c: &mut Criterion) {
} else {
0
};

{
// training
let train_batch_size = 100;
Expand All @@ -45,15 +83,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
));
{
let scalar_types = [ScalarType::BF16, ScalarType::F32];
let devices = if cfg!(feature = "device") {
vec![
Device::host(),
Device::builder().index(device_index).build().unwrap(),
]
} else {
vec![Device::host()]
};
for device in devices {
for device in autograph_devices(device_index) {
let device_name = if device.is_device() { "device" } else { "host" };
for scalar_type in scalar_types {
let scalar_name = scalar_type.name();
Expand All @@ -73,15 +103,10 @@ pub fn criterion_benchmark(c: &mut Criterion) {
}
#[cfg(feature = "tch")]
{
use tch::{kind::Kind, Device};
use tch::kind::Kind;

let kinds = [Kind::BFloat16, Kind::Float];
let devices = if cfg!(feature = "cuda") {
vec![Device::Cpu, Device::Cuda(cuda_device_index)]
} else {
vec![Device::Cpu]
};
for device in devices {
for device in tch_devices(cuda_device_index) {
let device_name = if device.is_cuda() { "device" } else { "host" };
for kind in kinds {
let kind_name = match kind {
Expand All @@ -104,27 +129,45 @@ pub fn criterion_benchmark(c: &mut Criterion) {
}
}
}
}
#[cfg(feature = "candle")]
{
use candle_core::DType;

let dtypes = [/* Not Supported DType::BF16,*/ DType::F32];
for device in candle_devices(cuda_device_index) {
let device_name = if device.is_cuda() { "device" } else { "host" };
for dtype in dtypes {
let scalar_name = match dtype {
//DType::BF16 => "bf16",
DType::F32 => "f32",
_ => unreachable!(),
};
let name = format!("{scalar_name}_{device_name}");
let id = BenchmarkId::new("candle", name);
g.bench_function(id, |b| {
use candle_backend::LeNet5Classifier;
let mut model = LeNet5Classifier::new(device.clone(), dtype)
.unwrap()
.with_sgd(false)
.unwrap();
b.iter(|| {
model.train(train_batch_size).unwrap();
});
});
}
}
}
}
{
// inference
let infer_batch_size = 1000;
let mut g = c.benchmark_group(format!(
"LeNet5(inference, batch_size = {})",
infer_batch_size.to_formatted_string(&Locale::en)
));

{
let scalar_types = [ScalarType::BF16, ScalarType::F32];
let devices = if cfg!(feature = "device") {
vec![
Device::host(),
Device::builder().index(device_index).build().unwrap(),
]
} else {
vec![Device::host()]
};
for device in devices {
for device in autograph_devices(device_index) {
let device_name = if device.is_device() { "device" } else { "host" };
for scalar_type in scalar_types {
let scalar_name = scalar_type.name();
Expand All @@ -140,18 +183,12 @@ pub fn criterion_benchmark(c: &mut Criterion) {
}
}
}

#[cfg(feature = "tch")]
{
use tch::{kind::Kind, Device};
use tch::kind::Kind;

let kinds = [Kind::BFloat16, Kind::Float];
let devices = if cfg!(feature = "cuda") {
vec![Device::Cpu, Device::Cuda(cuda_device_index)]
} else {
vec![Device::Cpu]
};
for device in devices {
for device in tch_devices(cuda_device_index) {
let device_name = if device.is_cuda() { "device" } else { "host" };
for kind in kinds {
let kind_name = match kind {
Expand All @@ -171,6 +208,34 @@ pub fn criterion_benchmark(c: &mut Criterion) {
}
}
}
#[cfg(feature = "candle")]
{
use candle_core::DType;

let dtypes = [/* Not Supported DType::BF16,*/ DType::F32];
for device in candle_devices(cuda_device_index) {
let device_name = if device.is_cuda() { "device" } else { "host" };
for dtype in dtypes {
let scalar_name = match dtype {
//DType::BF16 => "bf16",
DType::F32 => "f32",
_ => unreachable!(),
};
let name = format!("{scalar_name}_{device_name}");
let id = BenchmarkId::new("candle", name);
g.bench_function(id, |b| {
use candle_backend::LeNet5Classifier;
let model = LeNet5Classifier::new(device.clone(), dtype)
.unwrap()
.with_sgd(false)
.unwrap();
b.iter(|| {
model.infer(infer_batch_size).unwrap();
});
});
}
}
}
}
if cfg!(all(feature = "device", feature = "tch")) {
eprintln!("warning: sig abort in torch on exit when vulkan is used");
Expand Down
2 changes: 1 addition & 1 deletion benches/neural-network-benches/src/autograph_backend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::half::bf16;
use anyhow::Result;
use autograph::{
half::bf16,
krnl::{device::Device, scalar::ScalarType},
learn::{
criterion::CrossEntropyLoss,
Expand Down
101 changes: 101 additions & 0 deletions benches/neural-network-benches/src/candle_backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use anyhow::Result;
use candle_core::{DType, Device, Tensor, Var};
use candle_nn::{
conv2d_no_bias, linear, linear_no_bias, loss::cross_entropy, Conv2d, Conv2dConfig, Linear,
Module, Optimizer, VarBuilder, VarMap, SGD,
};

/// A LeNet5 model bundled with the device, dtype and optional SGD optimizer
/// needed to run the inference and training benchmark steps with candle.
pub struct LeNet5Classifier {
// Device on which the input/target tensors are created in `infer` / `train`.
device: Device,
// Element type of the model variables and of the input tensors.
dtype: DType,
// The network itself.
model: LeNet5,
// `None` after `new`; set to `Some` by `with_sgd`. Used by `train`.
optimizer: Option<SGD>,
// Variable map the model's parameters were created from; its variables are
// handed to the optimizer in `with_sgd`.
varmap: VarMap,
// Stored alongside the var map but never read in this file (hence the
// leading underscore).
_var_builder: VarBuilder<'static>,
}

impl LeNet5Classifier {
    /// Creates a LeNet5 classifier whose variables use `dtype` and live on `device`.
    ///
    /// No optimizer is attached; call [`with_sgd`](Self::with_sgd) before
    /// [`train`](Self::train).
    ///
    /// # Errors
    /// Returns an error if constructing any layer of the model fails.
    pub fn new(device: Device, dtype: DType) -> Result<Self> {
        let varmap = VarMap::new();
        let var_builder = VarBuilder::from_varmap(&varmap, dtype, &device);
        let model = LeNet5::new(&var_builder)?;
        Ok(Self {
            device,
            dtype,
            model,
            optimizer: None,
            varmap,
            _var_builder: var_builder,
        })
    }
    /// Attaches an SGD optimizer with a learning rate of 0.01 over all
    /// variables in the var map.
    ///
    /// # Errors
    /// Returns an error if `momentum` is `true` (rejected with
    /// "Momentum not supported by candle!"), or if the optimizer cannot be
    /// constructed.
    pub fn with_sgd(self, momentum: bool) -> Result<Self> {
        if momentum {
            anyhow::bail!("Momentum not supported by candle!");
        }
        let learning_rate = 0.01;
        let optimizer = SGD::new(self.varmap.all_vars(), learning_rate)?;
        Ok(Self {
            optimizer: Some(optimizer),
            ..self
        })
    }
    /// Runs one forward pass on an all-zeros input of shape
    /// `(batch_size, 1, 28, 28)` and discards the output.
    ///
    /// # Errors
    /// Returns an error if creating the input or the forward pass fails.
    pub fn infer(&self, batch_size: usize) -> Result<()> {
        let x = Tensor::zeros((batch_size, 1, 28, 28), self.dtype, &self.device)?;
        let _y = self.model.forward(&x)?;
        Ok(())
    }
    /// Runs one training step on an all-zeros input of shape
    /// `(batch_size, 1, 28, 28)` with all-zero `U32` targets: forward pass,
    /// cross-entropy loss, then an optimizer `backward_step`.
    ///
    /// # Errors
    /// Returns an error if no optimizer has been attached via
    /// [`with_sgd`](Self::with_sgd), or if any tensor operation fails.
    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        // Report a missing optimizer as an error instead of panicking, and do
        // so before spending time on the forward pass.
        let optimizer = self
            .optimizer
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("no optimizer configured; call `with_sgd` before `train`"))?;
        // NOTE(review): the input is created as a `Var`, not a plain `Tensor`;
        // presumably this makes the backward pass track the input as well —
        // confirm this is intended for the benchmark.
        let x = Var::zeros((batch_size, 1, 28, 28), self.dtype, &self.device)?;
        let t = Tensor::zeros(batch_size, DType::U32, &self.device)?;
        let y = self.model.forward(&x)?;
        let loss = cross_entropy(&y, &t)?;
        optimizer.backward_step(&loss)?;
        Ok(())
    }
}

/// The LeNet5 network: two convolution layers followed by three dense layers.
/// The layers are applied in declaration order by the `Module::forward` impl.
#[derive(Debug)]
struct LeNet5 {
// Convolution created with 1 input channel, 6 output channels, kernel size 5, no bias.
conv1: Conv2d,
// Convolution created with 6 input channels, 16 output channels, kernel size 5, no bias.
conv2: Conv2d,
// Dense layer created as 16 * 4 * 4 -> 128, no bias.
dense1: Linear,
// Dense layer created as 128 -> 84, no bias.
dense2: Linear,
// Dense layer created as 84 -> 10, with bias.
dense3: Linear,
}

impl LeNet5 {
    /// Builds the five layers, registering each one's variables under its
    /// field name ("conv1", "conv2", "dense1", "dense2", "dense3") via
    /// `var_builder.pp(..)`.
    ///
    /// Only the final dense layer has a bias; the other layers are created
    /// with the `*_no_bias` constructors.
    fn new(var_builder: &VarBuilder) -> Result<Self> {
        // Struct fields are evaluated in the order written, so the layers are
        // created in network order.
        Ok(Self {
            conv1: conv2d_no_bias(1, 6, 5, Conv2dConfig::default(), var_builder.pp("conv1"))?,
            conv2: conv2d_no_bias(6, 16, 5, Conv2dConfig::default(), var_builder.pp("conv2"))?,
            // 16 * 4 * 4 presumably matches the flattened conv2 output for
            // 28x28 inputs — TODO confirm against the conv/pool defaults.
            dense1: linear_no_bias(16 * 4 * 4, 128, var_builder.pp("dense1"))?,
            dense2: linear_no_bias(128, 84, var_builder.pp("dense2"))?,
            dense3: linear(84, 10, var_builder.pp("dense3"))?,
        })
    }
}

impl Module for LeNet5 {
    /// Applies the network to `xs`: two rounds of convolution, ReLU and
    /// `max_pool2d(2)`, a flatten from dimension 1, two dense layers with
    /// ReLU, and a final dense layer whose output is returned as-is.
    fn forward(&self, xs: &Tensor) -> Result<Tensor, candle_core::error::Error> {
        // Convolutional stages.
        let features = self.conv1.forward(xs)?.relu()?.max_pool2d(2)?;
        let features = self.conv2.forward(&features)?.relu()?.max_pool2d(2)?;
        // Collapse everything after dimension 0 before the dense layers.
        let flat = features.flatten_from(1)?;
        // Dense stages; no activation after the last layer.
        let hidden = self.dense1.forward(&flat)?.relu()?;
        let hidden = self.dense2.forward(&hidden)?.relu()?;
        self.dense3.forward(&hidden)
    }
}
3 changes: 3 additions & 0 deletions benches/neural-network-benches/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use autograph::half;
pub mod autograph_backend;
#[cfg(feature = "candle")]
pub mod candle_backend;
#[cfg(feature = "tch")]
pub mod tch_backend;
Loading

0 comments on commit 1f5eed2

Please sign in to comment.