@@ -50,23 +50,27 @@ def inputs(self, in_dim) -> torch.Tensor:
 
     @pytest.fixture
     def lora_linear(self, in_dim, out_dim) -> LoRALinear:
-        lora_linear = LoRALinear(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            rank=RANK,
-            alpha=ALPHA,
-            use_bias=True,
-        )
-        fixed_init_model(lora_linear)
-        return lora_linear
+        def create_lora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim):
+            with training.set_default_dtype(dtype):
+                lora_linear = LoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                )
+                fixed_init_model(lora_linear)
+            return lora_linear
+
+        return create_lora_linear
 
     @pytest.fixture
     def qlora_linear(self):
-        def create_qlora_linear(use_bias, dtype):
+        def create_qlora_linear(use_bias, dtype, in_dim=512, out_dim=512):
             with training.set_default_dtype(dtype):
                 qlora_linear = LoRALinear(
-                    in_dim=512,
-                    out_dim=512,
+                    in_dim=in_dim,
+                    out_dim=out_dim,
                     rank=RANK,
                     alpha=ALPHA,
                     use_bias=use_bias,
@@ -95,6 +99,7 @@ def set_dummy_weights_for_merge(self, lora_module):
         lora_module.lora_b.weight[32, 1] = 12
 
     def test_forward(self, inputs, lora_linear, out_dim) -> None:
+        lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
         expected = torch.tensor(EXPECTED_VAL)
         actual = lora_linear(inputs)
         assert actual.shape == (BSZ, SEQ_LEN, out_dim)
@@ -115,18 +120,13 @@ def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear):
         "use_bias, dtype",
         [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
     )
-    def test_qlora_parity(self, use_bias, dtype, qlora_linear):
-        qlora_linear = qlora_linear(use_bias=use_bias, dtype=dtype)
-        with training.set_default_dtype(dtype):
-            lora_linear = LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=use_bias,
-                quantize_base=False,
-            )
-        fixed_init_model(lora_linear, dtype=torch.bfloat16)
+    def test_qlora_parity(self, use_bias, dtype, qlora_linear, lora_linear):
+        qlora_linear = qlora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )
+        lora_linear = lora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )
 
         # set weight of lora_linear to unquantized weight of qlora_linear and check
         # parity.
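
The diff converts both the lora_linear and qlora_linear fixtures into factory fixtures: instead of returning a built module, each fixture returns an inner constructor so individual tests can choose use_bias, dtype, and the dimensions themselves. A minimal, standalone sketch of that pattern is below; it is illustrative only, and the names Widget, make_widget, and test_default_size are hypothetical, not part of this commit.

import pytest


class Widget:
    def __init__(self, size, use_bias):
        self.size = size
        self.use_bias = use_bias


class TestWidget:
    @pytest.fixture
    def make_widget(self):
        # The fixture yields a factory rather than a finished object, so each
        # test supplies its own arguments, mirroring how create_lora_linear
        # and create_qlora_linear are called inside the tests above.
        def _make(use_bias, size=512):
            return Widget(size=size, use_bias=use_bias)

        return _make

    def test_default_size(self, make_widget):
        widget = make_widget(use_bias=True)
        assert widget.size == 512
        assert widget.use_bias is True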