@@ -11,13 +11,13 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    MistralConfig, Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
-    FlashMistralModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -57,6 +57,8 @@ enum Config {
     #[serde(rename(deserialize = "nomic_bert"))]
     NomicBert(NomicConfig),
     Mistral(MistralConfig),
+    #[serde(rename = "new")]
+    Gte(GTEConfig),
 }
 
 pub struct CandleBackend {
@@ -215,6 +217,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -333,6 +339,17 @@ impl CandleBackend {
                     FlashMistralModel::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Gte(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(feature = "flash-attn")
+                    || get_runtime_compute_cap().unwrap() < 80
+                {
+                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashGTE model on {:?}", device);
+                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
+            }
         };
 
         Ok(Self {
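
For context, a minimal sketch of why the new variant carries `#[serde(rename = "new")]`: GTE checkpoints declare `"model_type": "new"` in their `config.json`, and the rename routes that value to `Config::Gte(GTEConfig)`. This is not the actual TEI code; the `tag = "model_type"` attribute and the `hidden_size` field are assumptions made purely for illustration.

```rust
// Minimal sketch, NOT the actual TEI code: demonstrates how
// `#[serde(rename = "new")]` maps a config.json whose `model_type`
// is "new" onto the `Gte` variant of a tagged config enum.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GTEConfig {
    hidden_size: usize, // hypothetical field, for illustration only
}

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")] // assumed tagging scheme for this sketch
enum Config {
    #[serde(rename = "new")]
    Gte(GTEConfig),
}

fn main() {
    let raw = r#"{ "model_type": "new", "hidden_size": 1024 }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    println!("{config:?}"); // prints: Gte(GTEConfig { hidden_size: 1024 })
}
```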