diff --git a/Cargo.lock b/Cargo.lock index 0e4b8086..c84f97c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -356,7 +356,7 @@ dependencies = [ [[package]] name = "candle-cublaslt" version = "0.2.2" -source = "git+https://github.com/huggingface/candle-cublaslt?rev=07e1a5490211e25ed0d096a2b21d3c607666eaae#07e1a5490211e25ed0d096a2b21d3c607666eaae" +source = "git+https://github.com/huggingface/candle-cublaslt?rev=ffd246552c266640fab217f964a83960e07a66ec#ffd246552c266640fab217f964a83960e07a66ec" dependencies = [ "candle-core", "cudarc", @@ -624,8 +624,8 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.9.14" -source = "git+https://github.com/OlivierDehaene/cudarc?rev=4c8e6d36a4a4c31e2e4649ae5246226452a01fc1#4c8e6d36a4a4c31e2e4649ae5246226452a01fc1" +version = "0.9.15" +source = "git+https://github.com/OlivierDehaene/cudarc?rev=8be6ff46e4a2014fb563570e0d206c09aea88152#8be6ff46e4a2014fb563570e0d206c09aea88152" dependencies = [ "half", ] diff --git a/Cargo.toml b/Cargo.toml index bea5e097..eccd5d12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-embeddings-inference" [patch.crates-io] -cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "4c8e6d36a4a4c31e2e4649ae5246226452a01fc1" } +cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "8be6ff46e4a2014fb563570e0d206c09aea88152" } candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" } candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" } candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" } diff --git a/README.md b/README.md index 265285bc..5c3e0320 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ Options: If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures [env: DTYPE=] - [possible values: float16] + [possible values: float16, float32] --pooling Optionally control the pooling method. 
diff --git a/backends/Cargo.toml b/backends/Cargo.toml
index 17c5106a..aaafec69 100644
--- a/backends/Cargo.toml
+++ b/backends/Cargo.toml
@@ -18,6 +18,7 @@ tracing = "^0.1"
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]
 candle = ["dep:text-embeddings-backend-candle"]
+cuda = ["text-embeddings-backend-candle?/cuda"]
 mkl = ["text-embeddings-backend-candle?/mkl"]
 mkl-dynamic = ["text-embeddings-backend-candle?/mkl-dynamic"]
 accelerate = ["text-embeddings-backend-candle?/accelerate"]
diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml
index bc126f41..4d8862dd 100644
--- a/backends/candle/Cargo.toml
+++ b/backends/candle/Cargo.toml
@@ -13,7 +13,7 @@ candle-nn = { version = "0.3.0" }
 candle-transformers = { version = "0.3.0" }
 candle-flash-attn = { version = "0.3.0", optional = true }
 candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
-candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "07e1a5490211e25ed0d096a2b21d3c607666eaae", optional = true }
+candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "ffd246552c266640fab217f964a83960e07a66ec", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
 lazy_static = "^1.4"
 text-embeddings-backend-core = { path = "../core" }
diff --git a/backends/candle/src/flash_attn.rs b/backends/candle/src/flash_attn.rs
index 757b7f63..cf6189f0 100644
--- a/backends/candle/src/flash_attn.rs
+++ b/backends/candle/src/flash_attn.rs
@@ -1,7 +1,7 @@
 use crate::compute_cap::RUNTIME_COMPUTE_CAP;
 use candle::Tensor;
 
-#[allow(clippy::too_many_arguments)]
+#[allow(clippy::too_many_arguments, unused)]
 pub(crate) fn flash_attn_varlen(
     q: &Tensor,
     k: &Tensor,
diff --git a/backends/candle/src/layers.rs b/backends/candle/src/layers.rs
new file mode 100644
index 00000000..f0bf7e60
--- /dev/null
+++ b/backends/candle/src/layers.rs
@@ -0,0 +1,8 @@
+#[allow(dead_code, unused)]
+mod cublaslt;
+mod layer_norm;
+mod linear;
+
+pub use cublaslt::CUBLASLT;
+pub use layer_norm::LayerNorm;
+pub use linear::{HiddenAct, Linear};
diff --git a/backends/candle/src/layers/cublaslt.rs b/backends/candle/src/layers/cublaslt.rs
new file mode 100644
index 00000000..08107c99
--- /dev/null
+++ b/backends/candle/src/layers/cublaslt.rs
@@ -0,0 +1,104 @@
+use crate::layers::HiddenAct;
+use candle::{Device, Result, Tensor};
+use lazy_static::lazy_static;
+
+#[cfg(feature = "cuda")]
+use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
+
+lazy_static! {
+    pub static ref CUBLASLT: Option<CublasLtWrapper> = {
+        match Device::cuda_if_available(0) {
+            Ok(device) => {
+                #[cfg(feature = "cuda")]
+                {
+                    Some(CublasLtWrapper {
+                        cublaslt: CublasLt::new(&device).unwrap(),
+                    })
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    None
+                }
+            }
+            Err(_) => None,
+        }
+    };
+}
+
+#[derive(Debug, Clone)]
+pub struct CublasLtWrapper {
+    #[cfg(feature = "cuda")]
+    pub cublaslt: CublasLt,
+}
+
+impl CublasLtWrapper {
+    #[allow(clippy::too_many_arguments)]
+    pub fn matmul(
+        &self,
+        a: &Tensor,
+        b: &Tensor,
+        out: Option<&Tensor>,
+        alpha: Option<f32>,
+        beta: Option<f32>,
+        bias: Option<&Tensor>,
+        act: Option<HiddenAct>,
+    ) -> Result<Tensor> {
+        #[cfg(feature = "cuda")]
+        {
+            let act = act.clone().map(|a| match a {
+                HiddenAct::Gelu => Activation::Gelu,
+                HiddenAct::Relu => Activation::Relu,
+            });
+
+            fused_matmul(
+                &a,
+                &b,
+                out,
+                alpha,
+                beta,
+                bias,
+                act.clone(),
+                self.cublaslt.clone(),
+            )
+        }
+        #[cfg(not(feature = "cuda"))]
+        {
+            candle::bail!("`cuda` feature is not enabled")
+        }
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub fn batch_matmul(
+        &self,
+        a: &Tensor,
+        b: &Tensor,
+        out: Option<&Tensor>,
+        alpha: Option<f32>,
+        beta: Option<f32>,
+        bias: Option<&Tensor>,
+        act: Option<HiddenAct>,
+    ) -> Result<Tensor> {
+        #[cfg(feature = "cuda")]
+        {
+            let act = act.clone().map(|a| match a {
+                HiddenAct::Gelu => Activation::Gelu,
+                HiddenAct::Relu => Activation::Relu,
+            });
+
+            fused_batch_matmul(
+                &a,
+                &b,
+                out,
+                alpha,
+                beta,
+                bias,
+                act.clone(),
+                self.cublaslt.clone(),
+            )
+        }
+        #[cfg(not(feature = "cuda"))]
+        {
+            candle::bail!("`cuda` feature is not enabled")
+        }
+    }
+}
diff --git a/backends/candle/src/layers/layer_norm.rs b/backends/candle/src/layers/layer_norm.rs
new file mode 100644
index 00000000..c1e2b7f6
--- /dev/null
+++ b/backends/candle/src/layers/layer_norm.rs
@@ -0,0 +1,74 @@
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    epsilon: f32,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
+        Ok(Self {
+            weight: vb
+                .get(hidden_size, "weight")
+                .or_else(|_| vb.get(hidden_size, "gamma"))?,
+            bias: vb
+                .get(hidden_size, "bias")
+                .or_else(|_| vb.get(hidden_size, "beta"))?,
+            epsilon,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        match hidden_states.device() {
+            Device::Cpu => {
+                let hidden_states = hidden_states.add(residual)?;
+                let hidden_states_dtype = hidden_states.dtype();
+                let internal_dtype = match hidden_states_dtype {
+                    DType::F16 | DType::BF16 => DType::F32,
+                    d => d,
+                };
+                let hidden_size = hidden_states.dim(D::Minus1)?;
+                let hidden_states = hidden_states.to_dtype(internal_dtype)?;
+                let mean_hidden_states =
+                    (hidden_states.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states = hidden_states.broadcast_sub(&mean_hidden_states)?;
+                let norm_hidden_states =
+                    (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states_normed = hidden_states
+                    .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
+                let hidden_states = hidden_states_normed
+                    .to_dtype(hidden_states_dtype)?
+                    .broadcast_mul(&self.weight)?;
+                hidden_states.broadcast_add(&self.bias)
+            }
+            Device::Cuda(_) => {
+                #[cfg(feature = "cuda")]
+                {
+                    use candle_layer_norm::fused_add_layer_norm;
+
+                    let original_shape = hidden_states.shape();
+                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
+                    let residual = residual.flatten_to(D::Minus2)?;
+
+                    let result = fused_add_layer_norm(
+                        &hidden_states,
+                        &residual,
+                        &self.weight,
+                        &self.bias,
+                        self.epsilon,
+                    )?;
+                    result.reshape(original_shape)
+                }
+                #[cfg(not(feature = "cuda"))]
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        }
+    }
+}
diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs
new file mode 100644
index 00000000..19ce2dc5
--- /dev/null
+++ b/backends/candle/src/layers/linear.rs
@@ -0,0 +1,73 @@
+use crate::layers::CUBLASLT;
+use candle::{Device, Result, Tensor, D};
+use serde::Deserialize;
+
+#[derive(Debug, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    act: Option<HiddenAct>,
+    span: tracing::Span,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>, act: Option<HiddenAct>) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+
+        Self {
+            weight,
+            bias,
+            act,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        #[allow(unused)]
+        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), &*CUBLASLT) {
+            // fused matmul requires x to be dims2
+            let mut final_shape = x.dims().to_vec();
+            final_shape.pop();
+            final_shape.push(self.weight.dims()[0]);
+
+            let x = x.flatten_to(D::Minus2)?;
+            let result = cublaslt.matmul(
+                &self.weight,
+                &x,
+                None,
+                None,
+                None,
+                self.bias.as_ref(),
+                self.act.clone(),
+            )?;
+            result.reshape(final_shape)
+        } else {
+            let w = match x.dims() {
+                &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+                _ => self.weight.t()?,
+            };
+            let x = x.matmul(&w)?;
+            let x = match &self.bias {
+                None => Ok(x),
+                Some(bias) => x.broadcast_add(bias),
+            }?;
+            if let Some(act) = &self.act {
+                match act {
+                    HiddenAct::Gelu => x.gelu(),
+                    HiddenAct::Relu => x.relu(),
+                }
+            } else {
+                Ok(x)
+            }
+        }
+    }
+}
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 6525c74e..8ba3f821 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -2,11 +2,14 @@
 mod compute_cap;
 #[cfg(feature = "cuda")]
 mod flash_attn;
+mod layers;
 mod models;
 
 #[cfg(feature = "cuda")]
 use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
-use crate::models::{BertModel, EmbeddingModel, QuantBertModel};
+#[cfg(feature = "cuda")]
+use crate::models::FlashBertModel;
+use crate::models::{BertModel, EmbeddingModel, PositionEmbeddingType, QuantBertModel};
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
 use models::Config;
@@ -88,8 +91,6 @@ impl CandleBackend {
             ));
             #[cfg(feature = "cuda")]
             {
-                use crate::models::FlashBertModel;
-
                 // Get candle dtype
                 let dtype = if &dtype == "float32" {
                     Ok(DType::F32)
@@ -119,8 +120,19 @@ impl CandleBackend {
                    return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
                }
 
-                tracing::info!("Starting FlashBert model on Cuda");
-                Box::new(FlashBertModel::load(vb, &config, pool).s()?)
+ if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) + && dtype == DType::F16 + && config.position_embedding_type == PositionEmbeddingType::Absolute + // Flash attention v1 precision problem with head_size == 32 + // See: https://github.com/huggingface/text-embeddings-inference/issues/37 + && !(*RUNTIME_COMPUTE_CAP == 75 && (config.hidden_size / config.num_attention_heads) == 32) + { + tracing::info!("Starting FlashBert model on Cuda"); + Box::new(FlashBertModel::load(vb, &config, pool).s()?) + } else { + tracing::info!("Starting Bert model on Cuda"); + Box::new(BertModel::load(vb, &config, pool).s()?) + } } } }; diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs index 7f9a7eab..ffbc83b8 100644 --- a/backends/candle/src/models.rs +++ b/backends/candle/src/models.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod bert; mod bert_quant; -pub use bert::{BertModel, Config}; +pub use bert::{BertModel, Config, PositionEmbeddingType}; pub use bert_quant::QuantBertModel; use candle::{Result, Tensor}; use text_embeddings_backend_core::Batch; diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index 61ed0370..7f6c82ac 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -1,3 +1,4 @@ +use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT}; use crate::models::EmbeddingModel; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; @@ -29,13 +30,6 @@ pub struct Config { pub id2label: Option>, } -#[derive(Debug, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum HiddenAct { - Gelu, - Relu, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] #[serde(rename_all = "lowercase")] pub enum PositionEmbeddingType { @@ -43,90 +37,6 @@ pub enum PositionEmbeddingType { Absolute, } -#[derive(Debug)] -struct LayerNorm { - weight: Tensor, - bias: Tensor, - epsilon: f64, - span: tracing::Span, -} - -impl LayerNorm { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - Ok(Self { - weight: vb - .get(config.hidden_size, "weight") - .or_else(|_| vb.get(config.hidden_size, "gamma"))?, - bias: vb - .get(config.hidden_size, "bias") - .or_else(|_| vb.get(config.hidden_size, "beta"))?, - epsilon: config.layer_norm_eps, - span: tracing::span!(tracing::Level::TRACE, "layer-norm"), - }) - } - - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - - let x_dtype = x.dtype(); - let internal_dtype = match x_dtype { - DType::F16 | DType::BF16 => DType::F32, - d => d, - }; - let hidden_size = x.dim(D::Minus1)?; - let x = x.to_dtype(internal_dtype)?; - let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; - let x = x.broadcast_sub(&mean_x)?; - let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? 
/ hidden_size as f64)?; - let x_normed = x.broadcast_div(&(norm_x + self.epsilon)?.sqrt()?)?; - let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; - x.broadcast_add(&self.bias) - } -} - -#[derive(Debug)] -pub struct Linear { - weight: Tensor, - bias: Option, - act: Option, - span: tracing::Span, -} - -impl Linear { - pub fn new(weight: Tensor, bias: Option, act: Option) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - - Self { - weight, - bias, - act, - span, - } - } - - pub fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, - _ => self.weight.t()?, - }; - let x = x.matmul(&w)?; - let x = match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - }?; - if let Some(act) = &self.act { - match act { - HiddenAct::Gelu => x.gelu(), - HiddenAct::Relu => x.relu(), - } - } else { - Ok(x) - } - } -} - #[derive(Debug)] struct BertEmbeddings { word_embeddings: Embedding, @@ -139,7 +49,7 @@ struct BertEmbeddings { impl BertEmbeddings { pub fn load(vb: VarBuilder, config: &Config) -> Result { if config.position_embedding_type != PositionEmbeddingType::Absolute { - candle::bail!("FlashBert only supports absolute position embeddings"); + candle::bail!("Bert only supports absolute position embeddings"); } Ok(Self { @@ -160,7 +70,11 @@ impl BertEmbeddings { )?, config.hidden_size, ), - layer_norm: LayerNorm::load(vb.pp("LayerNorm"), config)?, + layer_norm: LayerNorm::load( + vb.pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?, span: tracing::span!(tracing::Level::TRACE, "embeddings"), }) } @@ -177,10 +91,8 @@ impl BertEmbeddings { let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; let position_embeddings = self.position_embeddings.forward(position_ids)?; - let embeddings = input_embeddings - .add(&token_type_embeddings)? - .add(&position_embeddings)?; - let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = input_embeddings.add(&token_type_embeddings)?; + let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?; Ok(embeddings) } @@ -233,7 +145,11 @@ impl BertAttention { let dense = Linear::new(dense_weight, Some(dense_bias), None); - let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; + let layer_norm = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; let softmax_scale = 1. 
/ (attention_head_size as f64).sqrt(); @@ -250,6 +166,7 @@ impl BertAttention { fn forward(&self, hidden_states: &Tensor, attention_mask: Option<&Tensor>) -> Result { let _enter = self.span.enter(); + let device = hidden_states.device(); let residual = hidden_states.clone(); @@ -264,22 +181,78 @@ impl BertAttention { let qkv = qkv.chunk(3, 1)?; let query_layer = &qkv[0].contiguous()?; let key_layer = &qkv[1].contiguous()?; - let value_layer = &qkv[2].contiguous()?; - - let attention_scores = query_layer.matmul(&key_layer.t()?)?; - let mut attention_scores = (attention_scores * self.softmax_scale)?; + let value_layer = &qkv[2]; + + #[allow(unused_variables)] + let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = (device, &*CUBLASLT) { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?; + let key_layer = key_layer.flatten(0, 1)?; + let query_layer = query_layer.flatten(0, 1)?; + let value_layer = value_layer.flatten(0, 1)?; + let attention_mask = attention_mask.map(|mask| mask.flatten(0, 1)).transpose()?; + + // If attention_mask is set, we fuse the add by giving it as the output matrix + // and setting beta to 1.0 + let beta = match attention_mask.is_some() { + true => Some(1.0), + false => None, + }; + + // Batch matrix multiplication + // Fuse softmax scale and attention_mask add + let attention_scores = cublaslt.batch_matmul( + &key_layer, + &query_layer, + attention_mask.as_ref(), + Some(self.softmax_scale as f32), + beta, + None, + None, + )?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &value_layer.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&query_layer), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape(( + batch_size, + self.num_attention_heads, + seq_len, + self.attention_head_size, + )) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let mut attention_scores = (attention_scores * self.softmax_scale)?; - if let Some(attention_mask) = attention_mask { - attention_scores = attention_scores.add(attention_mask)?; - } + if let Some(attention_mask) = attention_mask { + attention_scores = attention_scores.add(attention_mask)?; + } - let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + attention_probs.matmul(&value_layer.contiguous()?) 
+ }?; - let context_layer = attention_probs.matmul(value_layer)?; let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; - let hidden_states = self.dense.forward(&context_layer)?.add(&residual)?; - let hidden_states = self.layer_norm.forward(&hidden_states)?; + let hidden_states = self.dense.forward(&context_layer)?; + let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; Ok(hidden_states) } @@ -321,7 +294,11 @@ impl BertLayer { .get(config.hidden_size, "bias")?; let output = Linear::new(output_weight, Some(output_bias), None); - let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; + let layer_norm = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; Ok(Self { attention, @@ -343,8 +320,8 @@ impl BertLayer { let residual = hidden_states.clone(); let hidden_states = self.intermediate.forward(&hidden_states)?; - let hidden_states = self.output.forward(&hidden_states)?.add(&residual)?; - let hidden_states = self.layer_norm.forward(&hidden_states)?; + let hidden_states = self.output.forward(&hidden_states)?; + let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; Ok(hidden_states) } @@ -386,21 +363,17 @@ pub struct BertModel { num_attention_heads: usize, - pub device: Device, + device: Device, + dtype: DType, span: tracing::Span, } impl BertModel { pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result { - match vb.device() { - Device::Cpu => {} - _ => candle::bail!("Bert requires CPU"), - } - // Check position embedding type if config.position_embedding_type != PositionEmbeddingType::Absolute { - candle::bail!("FlashBert only supports absolute position embeddings") + candle::bail!("Bert only supports absolute position embeddings") } // Check pool type @@ -438,6 +411,7 @@ impl BertModel { pool, num_attention_heads: config.num_attention_heads, device: vb.device().clone(), + dtype: vb.dtype(), span: tracing::span!(tracing::Level::TRACE, "model"), }) } @@ -496,7 +470,8 @@ impl BertModel { attention_mask, (batch_size, 1, 1, max_length), &self.device, - )?; + )? 
+ .to_dtype(self.dtype)?; // Broadcast once instead of at every layer let attention_mask = attention_mask .broadcast_as(( @@ -532,7 +507,8 @@ impl BertModel { let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?; let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; - let input_lengths = Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?; + let input_lengths = + Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; let embedding_output = self .embeddings diff --git a/backends/candle/src/models/bert_quant.rs b/backends/candle/src/models/bert_quant.rs index ca63790c..d0ff094c 100644 --- a/backends/candle/src/models/bert_quant.rs +++ b/backends/candle/src/models/bert_quant.rs @@ -1,4 +1,5 @@ -use crate::models::bert::{Config, HiddenAct, PositionEmbeddingType}; +use crate::layers::HiddenAct; +use crate::models::bert::{Config, PositionEmbeddingType}; use crate::models::EmbeddingModel; use candle::quantized::QMatMul; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index 7bdb7ba6..7dda1298 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -1,98 +1,17 @@ use crate::flash_attn::flash_attn_varlen; -use crate::models::bert::{Config, HiddenAct, PositionEmbeddingType}; +use crate::layers::{LayerNorm, Linear}; +use crate::models::bert::{Config, PositionEmbeddingType}; use crate::models::EmbeddingModel; use candle::{DType, Device, Result, Tensor}; -use candle_cublaslt::{fused_matmul, Activation, CublasLt}; -use candle_layer_norm::fused_add_layer_norm; use candle_nn::{Embedding, Module, VarBuilder}; use text_embeddings_backend_core::{Batch, Pool}; -#[derive(Debug)] -struct FastLayerNorm { - weight: Tensor, - bias: Tensor, - epsilon: f32, - span: tracing::Span, -} - -impl FastLayerNorm { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - Ok(Self { - weight: vb - .get(config.hidden_size, "weight") - .or_else(|_| vb.get(config.hidden_size, "gamma"))?, - bias: vb - .get(config.hidden_size, "bias") - .or_else(|_| vb.get(config.hidden_size, "beta"))?, - epsilon: config.layer_norm_eps as f32, - span: tracing::span!(tracing::Level::TRACE, "layer-norm"), - }) - } - - pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result { - let _enter = self.span.enter(); - - fused_add_layer_norm( - hidden_states, - residual, - &self.weight, - &self.bias, - self.epsilon, - ) - } -} - -#[derive(Debug)] -pub struct Linear { - weight: Tensor, - bias: Option, - act: Option, - cublaslt: CublasLt, - span: tracing::Span, -} - -impl Linear { - pub fn new( - weight: Tensor, - bias: Option, - act: Option, - cublaslt: CublasLt, - ) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - - let act = act.map(|a| match a { - HiddenAct::Gelu => Activation::Gelu, - HiddenAct::Relu => Activation::Relu, - }); - - Self { - weight, - bias, - act, - cublaslt, - span, - } - } - - pub fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - - fused_matmul( - &self.weight, - x, - self.bias.as_ref(), - self.act.clone(), - self.cublaslt.clone(), - ) - } -} - #[derive(Debug)] struct BertEmbeddings { word_embeddings: Embedding, token_type_embeddings: Embedding, position_embeddings: Embedding, - layer_norm: FastLayerNorm, + layer_norm: LayerNorm, span: tracing::Span, } @@ -120,7 +39,11 
@@ impl BertEmbeddings { )?, config.hidden_size, ), - layer_norm: FastLayerNorm::load(vb.pp("LayerNorm"), config)?, + layer_norm: LayerNorm::load( + vb.pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?, span: tracing::span!(tracing::Level::TRACE, "embeddings"), }) } @@ -148,7 +71,7 @@ impl BertEmbeddings { struct BertAttention { qkv_linear: Linear, dense: Linear, - layer_norm: FastLayerNorm, + layer_norm: LayerNorm, num_attention_heads: usize, attention_head_size: usize, @@ -158,7 +81,7 @@ struct BertAttention { } impl BertAttention { - pub fn load(vb: VarBuilder, config: &Config, cublaslt: CublasLt) -> Result { + pub fn load(vb: VarBuilder, config: &Config) -> Result { let attention_head_size = config.hidden_size / config.num_attention_heads; let all_head_size = config.num_attention_heads * attention_head_size; let hidden_size = config.hidden_size; @@ -179,7 +102,7 @@ impl BertAttention { let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?; - let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None, cublaslt.clone()); + let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None); let dense_weight = vb .pp("output") @@ -187,9 +110,13 @@ impl BertAttention { .get((hidden_size, hidden_size), "weight")?; let dense_bias = vb.pp("output").pp("dense").get(hidden_size, "bias")?; - let dense = Linear::new(dense_weight, Some(dense_bias), None, cublaslt.clone()); + let dense = Linear::new(dense_weight, Some(dense_bias), None); - let layer_norm = FastLayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; + let layer_norm = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; let softmax_scale = (1. 
/ (attention_head_size as f64).sqrt()) as f32; @@ -248,13 +175,13 @@ struct BertLayer { attention: BertAttention, intermediate: Linear, output: Linear, - layer_norm: FastLayerNorm, + layer_norm: LayerNorm, span: tracing::Span, } impl BertLayer { - pub fn load(vb: VarBuilder, config: &Config, cublaslt: CublasLt) -> Result { - let attention = BertAttention::load(vb.pp("attention"), config, cublaslt.clone())?; + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let attention = BertAttention::load(vb.pp("attention"), config)?; let intermediate_weight = vb .pp("intermediate") @@ -268,7 +195,6 @@ impl BertLayer { intermediate_weight, Some(intermediate_bias), Some(config.hidden_act.clone()), - cublaslt.clone(), ); let output_weight = vb @@ -279,9 +205,13 @@ impl BertLayer { .pp("output") .pp("dense") .get(config.hidden_size, "bias")?; - let output = Linear::new(output_weight, Some(output_bias), None, cublaslt.clone()); + let output = Linear::new(output_weight, Some(output_bias), None); - let layer_norm = FastLayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; + let layer_norm = LayerNorm::load( + vb.pp("output").pp("LayerNorm"), + config.hidden_size, + config.layer_norm_eps as f32, + )?; Ok(Self { attention, @@ -317,9 +247,9 @@ struct BertEncoder { } impl BertEncoder { - pub fn load(vb: VarBuilder, config: &Config, cublaslt: CublasLt) -> Result { + pub fn load(vb: VarBuilder, config: &Config) -> Result { let layers = (0..config.num_hidden_layers) - .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config, cublaslt.clone())) + .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config)) .collect::>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); @@ -370,11 +300,9 @@ impl FlashBertModel { candle::bail!("Pool type {pool:?} is not supported"); } - let cublaslt = CublasLt::new(vb.device()).unwrap(); - let (embeddings, encoder) = match ( BertEmbeddings::load(vb.pp("embeddings"), config), - BertEncoder::load(vb.pp("encoder"), config, cublaslt.clone()), + BertEncoder::load(vb.pp("encoder"), config), ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { @@ -382,16 +310,12 @@ impl FlashBertModel { if let (Ok(embeddings), Ok(encoder)) = ( BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config), - BertEncoder::load( - vb.pp(format!("{model_type}.encoder")), - config, - cublaslt.clone(), - ), + BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config), ) { (embeddings, encoder) } else if let (Ok(embeddings), Ok(encoder)) = ( BertEmbeddings::load(vb.pp("bert.embeddings"), config), - BertEncoder::load(vb.pp("bert.encoder"), config, cublaslt.clone()), + BertEncoder::load(vb.pp("bert.encoder"), config), ) { (embeddings, encoder) } else { diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index 3e4d7cf2..d2c896ce 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -13,14 +13,7 @@ pub enum DType { ))] Float16, // Float32 is not available on candle cuda - #[cfg(any( - feature = "python", - all( - feature = "candle", - not(feature = "flash-attn"), - not(feature = "flash-attn-v1") - ) - ))] + #[cfg(any(feature = "python", feature = "candle"))] Float32, // #[cfg(feature = "candle")] // Q6K, @@ -36,14 +29,7 @@ impl fmt::Display for DType { ))] DType::Float16 => write!(f, "float16"), // Float32 is not available on candle cuda - #[cfg(any( - feature = "python", - all( - feature = "candle", - not(feature = "flash-attn"), - not(feature = "flash-attn-v1") - ) - ))] + #[cfg(any(feature 
= "python", feature = "candle"))] DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), diff --git a/router/Cargo.toml b/router/Cargo.toml index f5acd0cf..d6f9a8d2 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -58,4 +58,5 @@ python = ["text-embeddings-backend/python"] candle = ["text-embeddings-backend/candle"] candle-cuda = ["candle", "text-embeddings-backend/flash-attn"] candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"] +candle-cuda-volta = ["candle", "text-embeddings-backend/cuda"] static-linking = ["text-embeddings-backend/static-linking"]