 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -325,7 +326,20 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "q_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wk = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "k_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "v_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wo = (
+            LoRALinear(
+                in_dim=args.n_kv_heads * args.head_dim,
+                out_dim=args.dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "output_proj" in args.target_modules
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
 
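For reference, the constructor above selects LoRALinear over nn.Linear per projection whenever that projection's name ("q_proj", "k_proj", "v_proj", "output_proj") appears in args.target_modules, and falls back to a plain linear layer otherwise. Below is a minimal, hypothetical sketch of the standard low-rank-adapter math that a LoRALinear with this signature typically computes (y = Wx + (alpha / rank) * B A x). It is an illustration only, not the executorch LoRALinear implementation; the class name LoRALinearSketch is invented for the example.

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative stand-in for a LoRA-augmented linear layer (assumed semantics)."""

    def __init__(self, in_dim, out_dim, rank, alpha, dropout=0.0, use_bias=False):
        super().__init__()
        self.scale = alpha / rank
        # Base projection; in fine-tuning setups this weight is typically frozen.
        self.base = nn.Linear(in_dim, out_dim, bias=use_bias)
        # Trainable low-rank factors: A projects down to `rank`, B projects back up.
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output plus the scaled low-rank update.
        return self.base(x) + self.scale * self.lora_b(self.lora_a(self.dropout(x)))

Under that reading, setting args.target_modules = ["q_proj", "k_proj", "v_proj", "output_proj"] routes all four projections through the LoRA branch, while leaving target_modules as None keeps the layer identical to the previous nn.Linear-only construction.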