Commit ba9c60e

delete unused test, migrate some test cases to key validation unit test
1 parent 1c3b4ef commit ba9c60e

File tree

1 file changed: +62 -132 lines changed

tests/torchtune/modules/peft/test_utils.py

Lines changed: 62 additions & 132 deletions
@@ -260,9 +260,10 @@ def test_set_trainable_params(
             lora_attn_modules,
             apply_lora_to_mlp,
             apply_lora_to_output,
-            full_model_state_dict_keys,
-            lora_state_dict_keys,
-            base_model_state_dict_keys,
+            base_missing,
+            base_unexpected,
+            lora_missing,
+            lora_unexpected,
             expected
             """
         ),
@@ -271,188 +272,117 @@ def test_set_trainable_params(
                 ["q_proj", "k_proj"],
                 False,
                 False,
-                ["q_proj.lora_a.weight", "dummy_param.weight"],
                 ["q_proj.lora_a.weight"],
+                [],
                 ["dummy_param.weight"],
+                [],
                 "",
             ),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["param_a", "param_b"],
-                None,
-                ["param_a", "param_b"],
-                "",
-            ),
+            (["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
             (
                 ["output_proj"],
                 False,
                 True,
-                ["output_proj.weight", "output_proj.lora_a.weight"],
                 ["output_proj.lora_a.weight"],
+                [],
                 ["output_proj.weight"],
+                [],
                 "",
             ),
-            (["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"),
             (
-                ["k_proj", "output_proj"],
+                ["q_proj"],
                 False,
-                True,
-                ["k_proj.lora_a.weight", "param_a"],
-                ["k_proj.lora_a.weight", "param_a"],
+                False,
+                ["param_a"],
+                [],
                 ["param_a"],
-                "found in LoRA",
+                [],
+                "Missing non-LoRA",
             ),
             (
-                ["k_proj"],
-                False,
+                ["k_proj", "output_proj"],
                 False,
-                ["k_proj.lora_a.weight"],
+                True,
+                [],
                 [],
                 ["k_proj.lora_a.weight"],
-                "found in base model",
+                [],
+                "Missing LoRA key",
             ),
             (
-                ["k_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                ["k_proj.lora_a.weight"],
+                ["k_proj.lora"],
+                [],
+                ["q_proj.lora"],
                 [],
-                None,
                 "Missing LoRA",
             ),
-            (["q_proj"], False, False, [], ["a"], ["a"], "overlapping"),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["dummy_param.weight"],
-                ["v_proj.lora_a.weight"],
-                ["dummy_param.weight"],
-                "Extra",
-            ),
             (
-                ["w1", "w2", "w3"],
+                ["q_proj", "k_proj"],
                 True,
                 False,
-                ["w1.lora_a.weight", "w2.weight", "q_proj.weight"],
-                ["w1.lora_a.weight"],
-                ["q_proj.weight"],
-                "Missing non-LoRA key",
+                ["k_proj.lora"],
+                [],
+                ["q_proj.magnitude"],
+                [],
+                "Missing LoRA",
             ),
             (
-                ["q_proj", "output"],
-                False,
+                ["q_proj", "k_proj"],
                 True,
-                [
-                    "q_proj.lora_a",
-                    "output.weight",
-                    "output.lora_a",
-                    "output_proj.lora_b",
-                ],
-                ["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
-                ["output.weight"],
-                "Missing non-LoRA key",
-            ),
-            (
-                ["q_proj", "v_proj"],
                 False,
-                False,
-                "lora_llama2_model_all_keys",
-                "lora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
+                ["output_proj.lora"],
+                [],
+                ["q_proj.lora"],
+                [],
+                "Missing non-LoRA",
             ),
             (
-                ["q_proj", "v_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                "dora_llama2_model_all_keys",
-                "dora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
-            ),
-        ],
-    )
-    def test_validate_lora_state_dict(
-        self,
-        request,
-        lora_attn_modules,
-        apply_lora_to_mlp,
-        apply_lora_to_output,
-        full_model_state_dict_keys,
-        lora_state_dict_keys,
-        base_model_state_dict_keys,
-        expected,
-    ):
-        if isinstance(full_model_state_dict_keys, str):
-            full_model_state_dict_keys = request.getfixturevalue(
-                full_model_state_dict_keys
-            )
-        if isinstance(lora_state_dict_keys, str):
-            lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
-        if isinstance(base_model_state_dict_keys, str):
-            base_model_state_dict_keys = request.getfixturevalue(
-                base_model_state_dict_keys
-            )
-        if expected:
-            with pytest.raises(AssertionError, match=expected):
-                validate_missing_and_unexpected_for_lora(
-                    lora_attn_modules,
-                    apply_lora_to_mlp,
-                    apply_lora_to_output,
-                    full_model_state_dict_keys=full_model_state_dict_keys,
-                    lora_state_dict_keys=lora_state_dict_keys,
-                    base_model_state_dict_keys=base_model_state_dict_keys,
-                )
-        else:
-            validate_missing_and_unexpected_for_lora(
-                lora_attn_modules,
-                apply_lora_to_mlp,
-                apply_lora_to_output,
-                full_model_state_dict_keys=full_model_state_dict_keys,
-                lora_state_dict_keys=lora_state_dict_keys,
-                base_model_state_dict_keys=base_model_state_dict_keys,
-            )
-
-    @pytest.mark.parametrize(
-        (
-            """
-            base_missing,
-            base_unexpected,
-            lora_missing,
-            lora_unexpected,
-            expected
-            """
-        ),
-        [
-            (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
-            (["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
-            (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
-            (
                 ["k_proj.lora"],
                 ["output.weight"],
                 ["q_proj.base_weight"],
                 [],
                 "loading base model",
             ),
             (
+                ["q_proj", "k_proj"],
+                True,
+                False,
                 ["k_proj.lora"],
                 [],
                 ["q_proj.base_weight"],
                 ["output.weight"],
                 "loading adapter",
             ),
-            (["k_proj.lora"], [], ["q_proj.base_weight"], [], ""),
+            (
+                ["q_proj", "k_proj"],
+                True,
+                False,
+                ["k_proj.lora"],
+                [],
+                ["q_proj.base_weight"],
+                [],
+                "",
+            ),
         ],
     )
     def test_validate_missing_and_unexpected_for_lora(
-        self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
+        self,
+        lora_attn_modules,
+        apply_lora_to_mlp,
+        apply_lora_to_output,
+        base_missing,
+        base_unexpected,
+        lora_missing,
+        lora_unexpected,
+        expected,
     ):
-        lora_attn_modules = ["q_proj", "k_proj"]
-        apply_lora_to_mlp = True
-        apply_lora_to_output = False
+
         if expected:
             with pytest.raises(AssertionError, match=expected):
                 validate_missing_and_unexpected_for_lora(
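For reference, here is a minimal standalone sketch of the shape the migrated test takes after this commit. The new call site is truncated in the hunk above, so the keyword names below (base_missing=, base_unexpected=, lora_missing=, lora_unexpected=) are an assumption inferred from the new parametrize fields rather than confirmed by the diff:

# Hypothetical sketch of the migrated test logic; assumes
# validate_missing_and_unexpected_for_lora accepts the four key lists
# as keyword arguments matching the parametrize field names.
import pytest

from torchtune.modules.peft import validate_missing_and_unexpected_for_lora


def check_case(
    lora_attn_modules,
    apply_lora_to_mlp,
    apply_lora_to_output,
    base_missing,
    base_unexpected,
    lora_missing,
    lora_unexpected,
    expected,
):
    if expected:
        # Error cases: the validator should raise an AssertionError whose
        # message matches the expected substring (e.g. "Missing LoRA").
        with pytest.raises(AssertionError, match=expected):
            validate_missing_and_unexpected_for_lora(
                lora_attn_modules,
                apply_lora_to_mlp,
                apply_lora_to_output,
                base_missing=base_missing,
                base_unexpected=base_unexpected,
                lora_missing=lora_missing,
                lora_unexpected=lora_unexpected,
            )
    else:
        # Happy path: adapter keys missing from the base-model load (and
        # base keys missing from the adapter load) are expected, so
        # validation should pass without raising.
        validate_missing_and_unexpected_for_lora(
            lora_attn_modules,
            apply_lora_to_mlp,
            apply_lora_to_output,
            base_missing=base_missing,
            base_unexpected=base_unexpected,
            lora_missing=lora_missing,
            lora_unexpected=lora_unexpected,
        )


# Mirrors the final parametrize case above: a LoRA key missing from the base
# load and a base key missing from the adapter load are both benign.
check_case(
    ["q_proj", "k_proj"], True, False,
    ["k_proj.lora"], [], ["q_proj.base_weight"], [], "",
)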
