@@ -260,9 +260,10 @@ def test_set_trainable_params(
             lora_attn_modules,
             apply_lora_to_mlp,
             apply_lora_to_output,
-            full_model_state_dict_keys,
-            lora_state_dict_keys,
-            base_model_state_dict_keys,
+            base_missing,
+            base_unexpected,
+            lora_missing,
+            lora_unexpected,
             expected
             """
         ),
@@ -271,188 +272,117 @@ def test_set_trainable_params(
271272 ["q_proj" , "k_proj" ],
272273 False ,
273274 False ,
274- ["q_proj.lora_a.weight" , "dummy_param.weight" ],
275275 ["q_proj.lora_a.weight" ],
276+ [],
276277 ["dummy_param.weight" ],
278+ [],
277279 "" ,
278280 ),
279- (
280- ["v_proj" ],
281- False ,
282- False ,
283- ["param_a" , "param_b" ],
284- None ,
285- ["param_a" , "param_b" ],
286- "" ,
287- ),
281+ (["v_proj" ], False , False , [], [], ["param_a" , "param_b" ], [], "" ),
288282 (
289283 ["output_proj" ],
290284 False ,
291285 True ,
292- ["output_proj.weight" , "output_proj.lora_a.weight" ],
293286 ["output_proj.lora_a.weight" ],
287+ [],
294288 ["output_proj.weight" ],
289+ [],
295290 "" ,
296291 ),
297- (["q_proj" ], False , False , ["param_a" ], [], [], "Missing non-LoRA" ),
298292 (
299- ["k_proj" , "output_proj " ],
293+ ["q_proj " ],
300294 False ,
301- True ,
302- ["k_proj.lora_a.weight" , " param_a" ],
303- ["k_proj.lora_a.weight" , "param_a" ],
295+ False ,
296+ ["param_a" ],
297+ [],
304298 ["param_a" ],
305- "found in LoRA" ,
299+ [],
300+ "Missing non-LoRA" ,
306301 ),
307302 (
308- ["k_proj" ],
309- False ,
303+ ["k_proj" , "output_proj" ],
310304 False ,
311- ["k_proj.lora_a.weight" ],
305+ True ,
306+ [],
312307 [],
313308 ["k_proj.lora_a.weight" ],
314- "found in base model" ,
309+ [],
310+ "Missing LoRA key" ,
315311 ),
316312 (
317- ["k_proj" ],
318- False ,
313+ ["q_proj" , " k_proj" ],
314+ True ,
319315 False ,
320- ["k_proj.lora_a.weight" ],
316+ ["k_proj.lora" ],
317+ [],
318+ ["q_proj.lora" ],
321319 [],
322- None ,
323320 "Missing LoRA" ,
324321 ),
325- (["q_proj" ], False , False , [], ["a" ], ["a" ], "overlapping" ),
326- (
327- ["v_proj" ],
328- False ,
329- False ,
330- ["dummy_param.weight" ],
331- ["v_proj.lora_a.weight" ],
332- ["dummy_param.weight" ],
333- "Extra" ,
334- ),
335322 (
336- ["w1 " , "w2" , "w3 " ],
323+ ["q_proj " , "k_proj " ],
337324 True ,
338325 False ,
339- ["w1.lora_a.weight" , "w2.weight" , "q_proj.weight" ],
340- ["w1.lora_a.weight" ],
341- ["q_proj.weight" ],
342- "Missing non-LoRA key" ,
326+ ["k_proj.lora" ],
327+ [],
328+ ["q_proj.magnitude" ],
329+ [],
330+ "Missing LoRA" ,
343331 ),
344332 (
345- ["q_proj" , "output" ],
346- False ,
333+ ["q_proj" , "k_proj" ],
347334 True ,
348- [
349- "q_proj.lora_a" ,
350- "output.weight" ,
351- "output.lora_a" ,
352- "output_proj.lora_b" ,
353- ],
354- ["q_proj.lora_a" , "output.lora_a" , "output_proj.lora_b" ],
355- ["output.weight" ],
356- "Missing non-LoRA key" ,
357- ),
358- (
359- ["q_proj" , "v_proj" ],
360335 False ,
361- False ,
362- "lora_llama2_model_all_keys" ,
363- "lora_llama2_expected_adapter_keys" ,
364- "lora_llama2_expected_base_model_keys" ,
365- "" ,
336+ [ "output_proj.lora" ] ,
337+ [] ,
338+ [ "q_proj.lora" ] ,
339+ [] ,
340+ "Missing non-LoRA " ,
366341 ),
367342 (
368- ["q_proj" , "v_proj " ],
369- False ,
343+ ["q_proj" , "k_proj " ],
344+ True ,
370345 False ,
371- "dora_llama2_model_all_keys" ,
372- "dora_llama2_expected_adapter_keys" ,
373- "lora_llama2_expected_base_model_keys" ,
374- "" ,
375- ),
376- ],
377- )
378- def test_validate_lora_state_dict (
379- self ,
380- request ,
381- lora_attn_modules ,
382- apply_lora_to_mlp ,
383- apply_lora_to_output ,
384- full_model_state_dict_keys ,
385- lora_state_dict_keys ,
386- base_model_state_dict_keys ,
387- expected ,
388- ):
389- if isinstance (full_model_state_dict_keys , str ):
390- full_model_state_dict_keys = request .getfixturevalue (
391- full_model_state_dict_keys
392- )
393- if isinstance (lora_state_dict_keys , str ):
394- lora_state_dict_keys = request .getfixturevalue (lora_state_dict_keys )
395- if isinstance (base_model_state_dict_keys , str ):
396- base_model_state_dict_keys = request .getfixturevalue (
397- base_model_state_dict_keys
398- )
399- if expected :
400- with pytest .raises (AssertionError , match = expected ):
401- validate_missing_and_unexpected_for_lora (
402- lora_attn_modules ,
403- apply_lora_to_mlp ,
404- apply_lora_to_output ,
405- full_model_state_dict_keys = full_model_state_dict_keys ,
406- lora_state_dict_keys = lora_state_dict_keys ,
407- base_model_state_dict_keys = base_model_state_dict_keys ,
408- )
409- else :
410- validate_missing_and_unexpected_for_lora (
411- lora_attn_modules ,
412- apply_lora_to_mlp ,
413- apply_lora_to_output ,
414- full_model_state_dict_keys = full_model_state_dict_keys ,
415- lora_state_dict_keys = lora_state_dict_keys ,
416- base_model_state_dict_keys = base_model_state_dict_keys ,
417- )
418-
419- @pytest .mark .parametrize (
420- (
421- """
422- base_missing,
423- base_unexpected,
424- lora_missing,
425- lora_unexpected,
426- expected
427- """
428- ),
429- [
430- (["k_proj.lora" ], [], ["q_proj.lora" ], [], "Missing LoRA" ),
431- (["k_proj.lora" ], [], ["q_proj.magnitude" ], [], "Missing LoRA" ),
432- (["output_proj.lora" ], [], ["q_proj.lora" ], [], "Missing non-LoRA" ),
433- (
434346 ["k_proj.lora" ],
435347 ["output.weight" ],
436348 ["q_proj.base_weight" ],
437349 [],
438350 "loading base model" ,
439351 ),
440352 (
353+ ["q_proj" , "k_proj" ],
354+ True ,
355+ False ,
441356 ["k_proj.lora" ],
442357 [],
443358 ["q_proj.base_weight" ],
444359 ["output.weight" ],
445360 "loading adapter" ,
446361 ),
447- (["k_proj.lora" ], [], ["q_proj.base_weight" ], [], "" ),
362+ (
363+ ["q_proj" , "k_proj" ],
364+ True ,
365+ False ,
366+ ["k_proj.lora" ],
367+ [],
368+ ["q_proj.base_weight" ],
369+ [],
370+ "" ,
371+ ),
448372 ],
449373 )
450374 def test_validate_missing_and_unexpected_for_lora (
451- self , base_missing , base_unexpected , lora_missing , lora_unexpected , expected
375+ self ,
376+ lora_attn_modules ,
377+ apply_lora_to_mlp ,
378+ apply_lora_to_output ,
379+ base_missing ,
380+ base_unexpected ,
381+ lora_missing ,
382+ lora_unexpected ,
383+ expected ,
452384 ):
453- lora_attn_modules = ["q_proj" , "k_proj" ]
454- apply_lora_to_mlp = True
455- apply_lora_to_output = False
385+
456386 if expected :
457387 with pytest .raises (AssertionError , match = expected ):
458388 validate_missing_and_unexpected_for_lora (
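For orientation, here is a minimal sketch of the consolidated test flow after this change. The updated call to `validate_missing_and_unexpected_for_lora` continues past the end of the hunk, so the keyword names for the four key lists and the import path below are assumptions inferred from the new parametrize header, not something shown in the diff itself.

```python
# Sketch only: keyword names and import path are inferred, not shown in the hunk.
import contextlib

import pytest

from torchtune.modules.peft import validate_missing_and_unexpected_for_lora


def check_case(
    lora_attn_modules,
    apply_lora_to_mlp,
    apply_lora_to_output,
    base_missing,
    base_unexpected,
    lora_missing,
    lora_unexpected,
    expected,
):
    # A non-empty `expected` string means the validator should raise an
    # AssertionError whose message matches it; an empty string means the
    # missing/unexpected key lists are consistent and nothing is raised.
    ctx = (
        pytest.raises(AssertionError, match=expected)
        if expected
        else contextlib.nullcontext()
    )
    with ctx:
        validate_missing_and_unexpected_for_lora(
            lora_attn_modules,
            apply_lora_to_mlp,
            apply_lora_to_output,
            base_missing=base_missing,
            base_unexpected=base_unexpected,
            lora_missing=lora_missing,
            lora_unexpected=lora_unexpected,
        )
```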