@@ -12,12 +12,12 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
-    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
+    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
-    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel, FlashQwen2Model,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -59,6 +59,7 @@ enum Config {
     Mistral(MistralConfig),
     #[serde(rename = "new")]
     Gte(GTEConfig),
+    Qwen2(Qwen2Config),
 }

 pub struct CandleBackend {
@@ -221,6 +222,10 @@ impl CandleBackend {
221222 "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
222223 . to_string ( ) ,
223224 ) ) ,
225+ ( Config :: Qwen2 ( _) , Device :: Cpu | Device :: Metal ( _) ) => Err ( BackendError :: Start (
226+ "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
227+ . to_string ( ) ,
228+ ) ) ,
224229 #[ cfg( feature = "cuda" ) ]
225230 ( Config :: Bert ( config) , Device :: Cuda ( _) ) => {
226231 if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
@@ -342,14 +347,25 @@ impl CandleBackend {
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
                 if dtype != DType::F16
-                    || !cfg!(feature = "flash-attn")
-                    || get_runtime_compute_cap().unwrap() < 80
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
                 }
                 tracing::info!("Starting FlashGTE model on {:?}", device);
                 Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Qwen2(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                {
+                    return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashQwen2 model on {:?}", device);
+                Ok(Box::new(
+                    FlashQwen2Model::load(vb, &config, model_type).s()?,
+                ))
+            }
         };

         Ok(Self {
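
Note on the dispatch: the new `Qwen2(Qwen2Config)` variant extends the `Config` enum that is deserialized from the model's `config.json`; the `#[serde(rename = "new")]` attribute on the GTE variant suggests the variant is selected by the `model_type` field, and `CandleBackend::new` then matches the variant together with the device and dtype to pick a model implementation. A minimal, self-contained sketch of that pattern (assuming `serde` with the derive feature and `serde_json`; the struct fields and serde attributes below are illustrative, not copied from the crate):

```rust
use serde::Deserialize;

// Hypothetical, trimmed-down stand-in for the crate's Qwen2Config.
#[derive(Debug, Deserialize)]
struct Qwen2Config {
    hidden_size: usize,
}

// Sketch of a `model_type`-tagged config enum; the tag and rename_all values
// are assumptions for illustration, not copied from the crate.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Qwen2(Qwen2Config),
}

fn main() {
    let raw = r#"{ "model_type": "qwen2", "hidden_size": 1024 }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    // CandleBackend::new would match on (config, device) here to choose a model.
    match config {
        Config::Qwen2(c) => println!("qwen2 config, hidden_size = {}", c.hidden_size),
    }
}
```

Under that reading, a checkpoint whose `config.json` declares `"model_type": "qwen2"` is routed to the new `FlashQwen2Model` arm on CUDA, and to the explicit error arm on CPU or Metal.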