From bafa9104c7e1baf944e4c075f0e3b4311c179cbe Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 23 Aug 2023 17:21:55 +0100
Subject: [PATCH 1/4] EfficientNet.

---
 candle-examples/examples/efficientnet/main.rs | 422 ++++++++++++++++++
 candle-nn/src/var_builder.rs                  |   4 +-
 2 files changed, 424 insertions(+), 2 deletions(-)
 create mode 100644 candle-examples/examples/efficientnet/main.rs

diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
new file mode 100644
index 0000000000..432e1906aa
--- /dev/null
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -0,0 +1,422 @@
+//! EfficientNet implementation.
+//!
+//! https://arxiv.org/abs/1905.11946
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn as nn;
+use nn::{Module, VarBuilder};
+
+// Based on the Python version from torchvision.
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
+#[derive(Debug, Clone, Copy)]
+pub struct MBConvConfig {
+    expand_ratio: f64,
+    kernel: usize,
+    stride: usize,
+    input_channels: usize,
+    out_channels: usize,
+    num_layers: usize,
+}
+
+fn make_divisible(v: f64, divisor: usize) -> usize {
+    let min_value = divisor;
+    let new_v = usize::max(
+        min_value,
+        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
+    );
+    if (new_v as f64) < 0.9 * v {
+        new_v + divisor
+    } else {
+        new_v
+    }
+}
+
+fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
+    let bneck_conf = |e, k, s, i, o, n| {
+        let input_channels = make_divisible(i as f64 * width_mult, 8);
+        let out_channels = make_divisible(o as f64 * width_mult, 8);
+        let num_layers = (n as f64 * depth_mult).ceil() as usize;
+        MBConvConfig {
+            expand_ratio: e,
+            kernel: k,
+            stride: s,
+            input_channels,
+            out_channels,
+            num_layers,
+        }
+    };
+    vec![
+        bneck_conf(1., 3, 1, 32, 16, 1),
+        bneck_conf(6., 3, 2, 16, 24, 2),
+        bneck_conf(6., 5, 2, 24, 40, 2),
+        bneck_conf(6., 3, 2, 40, 80, 3),
+        bneck_conf(6., 5, 1, 80, 112, 3),
+        bneck_conf(6., 5, 2, 112, 192, 4),
+        bneck_conf(6., 3, 1, 192, 320, 1),
+    ]
+}
+
+impl MBConvConfig {
+    fn b0() -> Vec<MBConvConfig> {
+        bneck_confs(1.0, 1.0)
+    }
+    fn b1() -> Vec<MBConvConfig> {
+        bneck_confs(1.0, 1.1)
+    }
+    fn b2() -> Vec<MBConvConfig> {
+        bneck_confs(1.1, 1.2)
+    }
+    fn b3() -> Vec<MBConvConfig> {
+        bneck_confs(1.2, 1.4)
+    }
+    fn b4() -> Vec<MBConvConfig> {
+        bneck_confs(1.4, 1.8)
+    }
+    fn b5() -> Vec<MBConvConfig> {
+        bneck_confs(1.6, 2.2)
+    }
+    fn b6() -> Vec<MBConvConfig> {
+        bneck_confs(1.8, 2.6)
+    }
+    fn b7() -> Vec<MBConvConfig> {
+        bneck_confs(2.0, 3.1)
+    }
+}
+
+/// Conv2D with same padding.
+#[derive(Debug)]
+struct Conv2DSame {
+    conv2d: nn::Conv2d,
+    s: usize,
+    k: usize,
+}
+
+impl Conv2DSame {
+    fn new(
+        vb: VarBuilder,
+        i: usize,
+        o: usize,
+        k: usize,
+        stride: usize,
+        groups: usize,
+        bias: bool,
+    ) -> Result<Self> {
+        let conv_config = nn::Conv2dConfig {
+            stride,
+            groups,
+            ..Default::default()
+        };
+        let conv2d = if bias {
+            nn::conv2d(i, o, k, conv_config, vb)?
+        } else {
+            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
+        };
+        Ok(Self {
+            conv2d,
+            s: stride,
+            k,
+        })
+    }
+}
+
+impl Module for Conv2DSame {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let s = self.s;
+        let k = self.k;
+        let (_, _, ih, iw) = xs.dims4()?;
+        let oh = (ih + s - 1) / s;
+        let ow = (iw + s - 1) / s;
+        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+        if pad_h > 0 || pad_w > 0 {
+            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
+            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
+            self.conv2d.forward(&xs)
+        } else {
+            self.conv2d.forward(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct ConvNormActivation {
+    conv2d: Conv2DSame,
+    bn2d: nn::BatchNorm,
+    activation: bool,
+}
+
+impl ConvNormActivation {
+    fn new(
+        vb: VarBuilder,
+        i: usize,
+        o: usize,
+        k: usize,
+        stride: usize,
+        groups: usize,
+    ) -> Result<Self> {
+        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
+        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
+        Ok(Self {
+            conv2d,
+            bn2d,
+            activation: true,
+        })
+    }
+
+    fn no_activation(self) -> Self {
+        Self {
+            activation: false,
+            ..self
+        }
+    }
+}
+
+impl Module for ConvNormActivation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.conv2d.forward(xs)?;
+        let xs = self.bn2d.forward(&xs)?;
+        if self.activation {
+            swish(&xs)
+        } else {
+            Ok(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct SqueezeExcitation {
+    fc1: Conv2DSame,
+    fc2: Conv2DSame,
+}
+
+impl SqueezeExcitation {
+    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
+        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
+        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
+        Ok(Self { fc1, fc2 })
+    }
+}
+
+impl Module for SqueezeExcitation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        // xs.adaptive_avg_pool2d([1, 1])
+        let xs = xs.avg_pool2d((1, 1), (1, 1))?;
+        let xs = self.fc1.forward(&xs)?;
+        let xs = swish(&xs)?;
+        let xs = self.fc2.forward(&xs)?;
+        let xs = nn::ops::sigmoid(&xs)?;
+        residual * xs
+    }
+}
+
+#[derive(Debug)]
+struct MBConv {
+    expand_cna: Option<ConvNormActivation>,
+    depthwise_cna: ConvNormActivation,
+    squeeze_excitation: SqueezeExcitation,
+    project_cna: ConvNormActivation,
+    config: MBConvConfig,
+}
+
+impl MBConv {
+    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
+        let vb = vb.pp("block");
+        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
+        let expand_cna = if exp != c.input_channels {
+            Some(ConvNormActivation::new(
+                vb.pp("0"),
+                c.input_channels,
+                exp,
+                1,
+                1,
+                1,
+            )?)
+        } else {
+            None
+        };
+        let start_index = if expand_cna.is_some() { 1 } else { 0 };
+        let depthwise_cna =
+            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
+        let squeeze_channels = usize::max(1, c.input_channels / 4);
+        let squeeze_excitation =
+            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
+        let project_cna =
+            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
+                .no_activation();
+        Ok(Self {
+            expand_cna,
+            depthwise_cna,
+            squeeze_excitation,
+            project_cna,
+            config: c,
+        })
+    }
+}
+
+impl Module for MBConv {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let use_res_connect =
+            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
+        let ys = match &self.expand_cna {
+            Some(expand_cna) => expand_cna.forward(xs)?,
+            None => xs.clone(),
+        };
+        let ys = self.depthwise_cna.forward(&ys)?;
+        let ys = self.squeeze_excitation.forward(&ys)?;
+        let ys = self.project_cna.forward(&ys)?;
+        if use_res_connect {
+            ys + xs
+        } else {
+            Ok(ys)
+        }
+    }
+}
+
+fn swish(s: &Tensor) -> Result<Tensor> {
+    s * nn::ops::sigmoid(s)?
+}
+
+#[derive(Debug)]
+struct EfficientNet {
+    init_cna: ConvNormActivation,
+    blocks: Vec<MBConv>,
+    final_cna: ConvNormActivation,
+    classifier: nn::Linear,
+}
+
+impl EfficientNet {
+    fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
+        let f_p = p.pp("features");
+        let first_in_c = configs[0].input_channels;
+        let last_out_c = configs.last().unwrap().out_channels;
+        let final_out_c = 4 * last_out_c;
+        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
+        let nconfigs = configs.len();
+        let mut blocks = vec![];
+        for (index, cnf) in configs.into_iter().enumerate() {
+            let f_p = f_p.pp(index + 1);
+            for r_index in 0..cnf.num_layers {
+                let cnf = if r_index == 0 {
+                    cnf
+                } else {
+                    MBConvConfig {
+                        input_channels: cnf.out_channels,
+                        stride: 1,
+                        ..cnf
+                    }
+                };
+                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
+            }
+        }
+        let final_cna =
+            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
+        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
+        Ok(Self {
+            init_cna,
+            blocks,
+            final_cna,
+            classifier,
+        })
+    }
+}
+
+impl Module for EfficientNet {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = self.init_cna.forward(xs)?;
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        let xs = self.final_cna.forward(&xs)?;
+        // TODO: xs.adaptive_avg_pool2d([1, 1])?
+        let xs = xs
+            .avg_pool2d((1, 1), (1, 1))?
+            .squeeze(D::Minus1)?
+            .squeeze(D::Minus1)?;
+        self.classifier.forward(&xs)
+    }
+}
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    B0,
+    B1,
+    B2,
+    B3,
+    B4,
+    B5,
+    B6,
+    B7,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Variant of the model to use.
+    #[arg(value_enum, long, default_value_t = Which::B2)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+    println!("loaded image {image:?}");
+
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("lmz/candle-dino-v2".into());
+            api.get("dinov2_vits14.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
+    let weights = weights.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let cfg = match args.which {
+        Which::B0 => MBConvConfig::b0(),
+        Which::B1 => MBConvConfig::b1(),
+        Which::B2 => MBConvConfig::b2(),
+        Which::B3 => MBConvConfig::b3(),
+        Which::B4 => MBConvConfig::b4(),
+        Which::B5 => MBConvConfig::b5(),
+        Which::B6 => MBConvConfig::b6(),
+        Which::B7 => MBConvConfig::b7(),
+    };
+    let model = EfficientNet::new(vb, cfg, candle_examples::imagenet::CLASS_COUNT as usize)?;
+    println!("model built");
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index b24ed56dfb..ef5b6fd1c6 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -129,7 +129,7 @@ impl<'a> VarBuilder<'a> {
         })
     }
 
-    pub fn push_prefix(&self, s: &str) -> Self {
+    pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
         Self {
@@ -139,7 +139,7 @@
     }
 
     /// Short alias for `push_prefix`.
-    pub fn pp(&self, s: &str) -> Self {
+    pub fn pp<S: ToString>(&self, s: S) -> Self {
         self.push_prefix(s)
     }
 

From a378facc3bbe7f83427874452c8cb3fa6a1ac715 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 23 Aug 2023 17:30:15 +0100
Subject: [PATCH 2/4] Complete the efficientnet implementation.

---
 candle-examples/examples/efficientnet/main.rs | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index 432e1906aa..09ec764908 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -208,8 +208,8 @@ impl SqueezeExcitation {
 impl Module for SqueezeExcitation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
-        // xs.adaptive_avg_pool2d([1, 1])
-        let xs = xs.avg_pool2d((1, 1), (1, 1))?;
+        // equivalent to adaptive_avg_pool2d([1, 1])
+        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
         let xs = self.fc1.forward(&xs)?;
         let xs = swish(&xs)?;
         let xs = self.fc2.forward(&xs)?;
@@ -336,11 +336,8 @@ impl Module for EfficientNet {
             xs = block.forward(&xs)?
         }
         let xs = self.final_cna.forward(&xs)?;
-        // TODO: xs.adaptive_avg_pool2d([1, 1])?
-        let xs = xs
-            .avg_pool2d((1, 1), (1, 1))?
-            .squeeze(D::Minus1)?
-            .squeeze(D::Minus1)?;
+        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
+        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
         self.classifier.forward(&xs)
     }
 }

From c2dc18332c9ef8411fe10d11ebf2a08135d1a295 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 23 Aug 2023 17:39:36 +0100
Subject: [PATCH 3/4] Improve group handling.

---
 candle-nn/src/conv.rs | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 204402c382..5c53c8da43 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -124,7 +124,11 @@ pub fn conv1d(
     vs: crate::VarBuilder,
 ) -> Result<Conv1d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?;
+    let ws = vs.get_or_init(
+        (out_channels, in_channels / cfg.groups, kernel_size),
+        "weight",
+        init_ws,
+    )?;
     let bound = 1. / (in_channels as f64).sqrt();
     let init_bs = crate::Init::Uniform {
         lo: -bound,
@@ -143,7 +147,12 @@ pub fn conv2d(
 ) -> Result<Conv2d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vs.get_or_init(
-        (out_channels, in_channels, kernel_size, kernel_size),
+        (
+            out_channels,
+            in_channels / cfg.groups,
+            kernel_size,
+            kernel_size,
+        ),
         "weight",
         init_ws,
     )?;
@@ -165,7 +174,12 @@ pub fn conv2d_no_bias(
 ) -> Result<Conv2d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vs.get_or_init(
-        (out_channels, in_channels, kernel_size, kernel_size),
+        (
+            out_channels,
+            in_channels / cfg.groups,
+            kernel_size,
+            kernel_size,
+        ),
         "weight",
         init_ws,
     )?;

From 505eabf078a35c445512ef4f67103c320fc22cf9 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 23 Aug 2023 17:57:39 +0100
Subject: [PATCH 4/4] Get the efficientnet to work.

---
 candle-core/src/conv.rs                       | 16 ++++++++++------
 candle-examples/examples/efficientnet/main.rs |  2 +-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d4b7a76ddc..77d4c5cd1b 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -93,8 +93,8 @@ impl Tensor {
         let params = ParamsConv1D {
             b_size,
             l_in,
-            c_out,
-            c_in,
+            c_out: c_out / groups,
+            c_in: c_in / groups,
             k_size,
             padding,
             stride,
@@ -103,9 +103,11 @@ impl Tensor {
             self.conv1d_single_group(kernel, &params)
         } else {
             let blocks = self.chunk(groups, 1)?;
+            let kernel = kernel.chunk(groups, 0)?;
             let blocks = blocks
                 .iter()
-                .map(|block| block.conv1d_single_group(kernel, &params))
+                .zip(&kernel)
+                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                 .collect::<Result<Vec<_>>>()?;
             Tensor::cat(&blocks, 1)
         }
@@ -146,8 +148,8 @@ impl Tensor {
             i_w,
             k_h,
             k_w,
-            c_out,
-            c_in,
+            c_out: c_out / groups,
+            c_in: c_in / groups,
             padding,
             stride,
         };
@@ -155,9 +157,11 @@ impl Tensor {
             self.conv2d_single_group(kernel, &params)
         } else {
             let blocks = self.chunk(groups, 1)?;
+            let kernel = kernel.chunk(groups, 0)?;
             let blocks = blocks
                 .iter()
-                .map(|block| block.conv2d_single_group(kernel, &params))
+                .zip(&kernel)
+                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                 .collect::<Result<Vec<_>>>()?;
             Tensor::cat(&blocks, 1)
         }
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index 09ec764908..fb6a5806e8 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -214,7 +214,7 @@ impl Module for SqueezeExcitation {
         let xs = swish(&xs)?;
         let xs = self.fc2.forward(&xs)?;
         let xs = nn::ops::sigmoid(&xs)?;
-        residual * xs
+        residual.broadcast_mul(&xs)
     }
 }
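
The snippet below is not part of the patches above; it is a minimal, std-only sketch of the "same" padding arithmetic that the Conv2DSame wrapper in PATCH 1 implements (output size ceil(i / s), total padding (o - 1) * s + k - i, split as evenly as possible). The same_padding helper name and the sample sizes are illustrative only.

/// Compute (left, right) "same" padding for one spatial dimension, assuming
/// `input` pixels, a `kernel`-sized window and a given `stride`.
fn same_padding(input: usize, kernel: usize, stride: usize) -> (usize, usize) {
    // "Same" padding keeps the output size at ceil(input / stride).
    let output = (input + stride - 1) / stride;
    // Total padding so that a stride-`stride`, size-`kernel` window yields
    // `output` positions; saturating_sub avoids usize underflow when no
    // padding is needed.
    let total = ((output - 1) * stride + kernel).saturating_sub(input);
    // Split the padding, putting the extra pixel on the right/bottom side,
    // matching the `pad - pad / 2` split used in Conv2DSame::forward.
    (total / 2, total - total / 2)
}

fn main() {
    // 224x224 input, 3x3 kernel, stride 2 -> 112x112 output, padded (0, 1).
    assert_eq!(same_padding(224, 3, 2), (0, 1));
    // Stride 1 keeps the spatial size; a 3x3 kernel needs (1, 1) padding.
    assert_eq!(same_padding(224, 3, 1), (1, 1));
    println!("same-padding sketch checks passed");
}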