6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -17,7 +17,7 @@ authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-embeddings-inference"

[patch.crates-io]
cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "4c8e6d36a4a4c31e2e4649ae5246226452a01fc1" }
cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "8be6ff46e4a2014fb563570e0d206c09aea88152" }
candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" }
candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" }
candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" }
2 changes: 1 addition & 1 deletion README.md
@@ -132,7 +132,7 @@ Options:
If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures

[env: DTYPE=]
[possible values: float16]
[possible values: float16, float32]

--pooling <POOLING>
Optionally control the pooling method.
1 change: 1 addition & 0 deletions backends/Cargo.toml
@@ -18,6 +18,7 @@ tracing = "^0.1"
clap = ["dep:clap", "text-embeddings-backend-core/clap"]
python = ["dep:text-embeddings-backend-python"]
candle = ["dep:text-embeddings-backend-candle"]
cuda = ["text-embeddings-backend-candle?/cuda"]
mkl = ["text-embeddings-backend-candle?/mkl"]
mkl-dynamic = ["text-embeddings-backend-candle?/mkl-dynamic"]
accelerate = ["text-embeddings-backend-candle?/accelerate"]
2 changes: 1 addition & 1 deletion backends/candle/Cargo.toml
@@ -13,7 +13,7 @@ candle-nn = { version = "0.3.0" }
candle-transformers = { version = "0.3.0" }
candle-flash-attn = { version = "0.3.0", optional = true }
candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "07e1a5490211e25ed0d096a2b21d3c607666eaae", optional = true }
candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "ffd246552c266640fab217f964a83960e07a66ec", optional = true }
candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
lazy_static = "^1.4"
text-embeddings-backend-core = { path = "../core" }
2 changes: 1 addition & 1 deletion backends/candle/src/flash_attn.rs
@@ -1,7 +1,7 @@
use crate::compute_cap::RUNTIME_COMPUTE_CAP;
use candle::Tensor;

#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments, unused)]
pub(crate) fn flash_attn_varlen(
    q: &Tensor,
    k: &Tensor,
8 changes: 8 additions & 0 deletions backends/candle/src/layers.rs
@@ -0,0 +1,8 @@
#[allow(dead_code, unused)]
mod cublaslt;
mod layer_norm;
mod linear;

pub use cublaslt::CUBLASLT;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
104 changes: 104 additions & 0 deletions backends/candle/src/layers/cublaslt.rs
@@ -0,0 +1,104 @@
use crate::layers::HiddenAct;
use candle::{Device, Result, Tensor};
use lazy_static::lazy_static;

#[cfg(feature = "cuda")]
use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};

lazy_static! {
    pub static ref CUBLASLT: Option<CublasLtWrapper> = {
        match Device::cuda_if_available(0) {
            Ok(device) => {
                #[cfg(feature = "cuda")]
                {
                    Some(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    })
                }
                #[cfg(not(feature = "cuda"))]
                {
                    None
                }
            }
            Err(_) => None,
        }
    };
}

#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    #[allow(clippy::too_many_arguments)]
    pub fn matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<HiddenAct>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let act = act.clone().map(|a| match a {
                HiddenAct::Gelu => Activation::Gelu,
                HiddenAct::Relu => Activation::Relu,
            });

            fused_matmul(
                &a,
                &b,
                out,
                alpha,
                beta,
                bias,
                act.clone(),
                self.cublaslt.clone(),
            )
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle::bail!("`cuda` feature is not enabled")
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<HiddenAct>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let act = act.clone().map(|a| match a {
                HiddenAct::Gelu => Activation::Gelu,
                HiddenAct::Relu => Activation::Relu,
            });

            fused_batch_matmul(
                &a,
                &b,
                out,
                alpha,
                beta,
                bias,
                act.clone(),
                self.cublaslt.clone(),
            )
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle::bail!("`cuda` feature is not enabled")
        }
    }
}
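
Editor's note — a hypothetical usage sketch, not part of the diff: how attention scores might be computed through the shared `CUBLASLT` handle from inside this backend crate. It assumes `batch_matmul` follows the same `(a, b) -> b * a^T` convention that `Linear::forward` (shown below) relies on for `matmul`, and that `q`/`k` are laid out as `[batch * heads, seq_len, head_size]`.

use crate::layers::CUBLASLT;
use candle::{Result, Tensor};

pub fn attention_scores(q: &Tensor, k: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // `CUBLASLT` is only `Some` when a CUDA device was found at startup.
    let cublaslt = CUBLASLT.as_ref().expect("no CUDA device available");
    // alpha scales the product; no accumulator, bias, or fused activation:
    // scores = softmax_scale * q @ k^T
    cublaslt.batch_matmul(k, q, None, Some(softmax_scale), None, None, None)
}
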
74 changes: 74 additions & 0 deletions backends/candle/src/layers/layer_norm.rs
@@ -0,0 +1,74 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    epsilon: f32,
    span: tracing::Span,
}

impl LayerNorm {
    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
        Ok(Self {
            weight: vb
                .get(hidden_size, "weight")
                .or_else(|_| vb.get(hidden_size, "gamma"))?,
            bias: vb
                .get(hidden_size, "bias")
                .or_else(|_| vb.get(hidden_size, "beta"))?,
            epsilon,
            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
        })
    }

    pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();

        match hidden_states.device() {
            Device::Cpu => {
                let hidden_states = hidden_states.add(residual)?;
                let hidden_states_dtype = hidden_states.dtype();
                let internal_dtype = match hidden_states_dtype {
                    DType::F16 | DType::BF16 => DType::F32,
                    d => d,
                };
                let hidden_size = hidden_states.dim(D::Minus1)?;
                let hidden_states = hidden_states.to_dtype(internal_dtype)?;
                let mean_hidden_states =
                    (hidden_states.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
                let hidden_states = hidden_states.broadcast_sub(&mean_hidden_states)?;
                let norm_hidden_states =
                    (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
                let hidden_states_normed = hidden_states
                    .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
                let hidden_states = hidden_states_normed
                    .to_dtype(hidden_states_dtype)?
                    .broadcast_mul(&self.weight)?;
                hidden_states.broadcast_add(&self.bias)
            }
            Device::Cuda(_) => {
                #[cfg(feature = "cuda")]
                {
                    use candle_layer_norm::fused_add_layer_norm;

                    let original_shape = hidden_states.shape();
                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
                    let residual = residual.flatten_to(D::Minus2)?;

                    let result = fused_add_layer_norm(
                        &hidden_states,
                        &residual,
                        &self.weight,
                        &self.bias,
                        self.epsilon,
                    )?;
                    result.reshape(original_shape)
                }
                #[cfg(not(feature = "cuda"))]
                candle::bail!("`cuda` feature is not enabled")
            }
        }
    }
}
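
Editor's note: the CPU branch above is a reference implementation of layer normalization fused with the residual addition, computed over the last dimension of size d (in float32 when the input is half precision):

h = x + r,\qquad
\mu = \frac{1}{d}\sum_{i=1}^{d} h_i,\qquad
\sigma^2 = \frac{1}{d}\sum_{i=1}^{d}\left(h_i - \mu\right)^2,\qquad
y_i = \gamma_i \, \frac{h_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i

where gamma and beta are the `weight` and `bias` tensors. The CUDA branch delegates the same computation to the `fused_add_layer_norm` kernel from `candle-layer-norm`.
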
73 changes: 73 additions & 0 deletions backends/candle/src/layers/linear.rs
@@ -0,0 +1,73 @@
use crate::layers::CUBLASLT;
use candle::{Device, Result, Tensor, D};
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    Gelu,
    Relu,
}

#[derive(Debug)]
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
    act: Option<HiddenAct>,
    span: tracing::Span,
}

impl Linear {
    pub fn new(weight: Tensor, bias: Option<Tensor>, act: Option<HiddenAct>) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "linear");

        Self {
            weight,
            bias,
            act,
            span,
        }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();

        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), &*CUBLASLT) {
            // fused matmul requires x to be dims2
            let mut final_shape = x.dims().to_vec();
            final_shape.pop();
            final_shape.push(self.weight.dims()[0]);

            let x = x.flatten_to(D::Minus2)?;
            let result = cublaslt.matmul(
                &self.weight,
                &x,
                None,
                None,
                None,
                self.bias.as_ref(),
                self.act.clone(),
            )?;
            result.reshape(final_shape)
        } else {
            let w = match x.dims() {
                &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
                _ => self.weight.t()?,
            };
            let x = x.matmul(&w)?;
            let x = match &self.bias {
                None => Ok(x),
                Some(bias) => x.broadcast_add(bias),
            }?;
            if let Some(act) = &self.act {
                match act {
                    HiddenAct::Gelu => x.gelu(),
                    HiddenAct::Relu => x.relu(),
                }
            } else {
                Ok(x)
            }
        }
    }
}
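
Editor's note — a hypothetical usage sketch, not part of the diff: a BERT-style feed-forward block built from the `Linear` layer above. When the `cuda` feature is enabled and a CUDA device is present, the GELU of the first projection is fused into the cuBLASLt matmul; otherwise the plain matmul + activation fallback runs. The `FeedForward` type and the weight names are illustrative, not taken from the PR.

use crate::layers::{HiddenAct, Linear};
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

pub struct FeedForward {
    intermediate: Linear, // [intermediate_size, hidden_size] weight, GELU fused on CUDA
    output: Linear,       // [hidden_size, intermediate_size] weight, no activation
}

impl FeedForward {
    pub fn load(vb: VarBuilder, hidden_size: usize, intermediate_size: usize) -> Result<Self> {
        let int_vb = vb.pp("intermediate.dense");
        let out_vb = vb.pp("output.dense");
        Ok(Self {
            intermediate: Linear::new(
                int_vb.get((intermediate_size, hidden_size), "weight")?,
                Some(int_vb.get(intermediate_size, "bias")?),
                Some(HiddenAct::Gelu),
            ),
            output: Linear::new(
                out_vb.get((hidden_size, intermediate_size), "weight")?,
                Some(out_vb.get(hidden_size, "bias")?),
                None,
            ),
        })
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Both calls dispatch to the fused cuBLASLt path or the CPU fallback in Linear::forward.
        let hidden_states = self.intermediate.forward(hidden_states)?;
        self.output.forward(&hidden_states)
    }
}
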
22 changes: 17 additions & 5 deletions backends/candle/src/lib.rs
@@ -2,11 +2,14 @@
mod compute_cap;
#[cfg(feature = "cuda")]
mod flash_attn;
mod layers;
mod models;

#[cfg(feature = "cuda")]
use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
use crate::models::{BertModel, EmbeddingModel, QuantBertModel};
#[cfg(feature = "cuda")]
use crate::models::FlashBertModel;
use crate::models::{BertModel, EmbeddingModel, PositionEmbeddingType, QuantBertModel};
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::Config;
@@ -88,8 +91,6 @@ impl CandleBackend {
                ));
                #[cfg(feature = "cuda")]
                {
                    use crate::models::FlashBertModel;

                    // Get candle dtype
                    let dtype = if &dtype == "float32" {
                        Ok(DType::F32)
@@ -119,8 +120,19 @@ impl CandleBackend {
                        return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
                    }

                    tracing::info!("Starting FlashBert model on Cuda");
                    Box::new(FlashBertModel::load(vb, &config, pool).s()?)
                    if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                        && dtype == DType::F16
                        && config.position_embedding_type == PositionEmbeddingType::Absolute
                        // Flash attention v1 precision problem with head_size == 32
                        // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                        && !(*RUNTIME_COMPUTE_CAP == 75 && (config.hidden_size / config.num_attention_heads) == 32)
                    {
                        tracing::info!("Starting FlashBert model on Cuda");
                        Box::new(FlashBertModel::load(vb, &config, pool).s()?)
                    } else {
                        tracing::info!("Starting Bert model on Cuda");
                        Box::new(BertModel::load(vb, &config, pool).s()?)
                    }
                }
            }
        };
2 changes: 1 addition & 1 deletion backends/candle/src/models.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
mod bert;
mod bert_quant;

pub use bert::{BertModel, Config};
pub use bert::{BertModel, Config, PositionEmbeddingType};
pub use bert_quant::QuantBertModel;
use candle::{Result, Tensor};
use text_embeddings_backend_core::Batch;
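
Editor's note: `PositionEmbeddingType` itself is defined in `bert.rs`, which this diff does not touch. A rough sketch of its shape is below; only the `Absolute` variant and the equality comparison (used by the gate in `lib.rs` above) are confirmed by this PR — the derive list, the serde attribute, and any other variants are assumptions.

use serde::Deserialize;

// Sketch only: everything beyond `Absolute` and `PartialEq` is an assumption
// about the real definition in bert.rs.
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    // further variants, if any, omitted
}
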