1 change: 0 additions & 1 deletion docs/source/api_ref_modules.rst
@@ -77,7 +77,6 @@ PEFT Components
peft.set_trainable_params
peft.get_adapter_state_dict
peft.validate_missing_and_unexpected_for_lora
peft.validate_state_dict_for_lora

@SalmanMohammadi (Contributor) commented on Nov 16, 2024:

Line 209 in the LoRA finetune tutorial:

.. note::
    Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in
    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g.
    :func:`validate_state_dict_for_lora() <torchtune.modules.peft.validate_state_dict_for_lora>` or
    :func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.

Needs to be updated
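
For reference, here is a minimal sketch (not the recipes' verbatim code) of the validation path the updated note should describe: the missing and unexpected keys returned by the strict=False loads are passed to the remaining validator. The names model, cfg_model, base_model_state_dict, and lora_weights_state_dict are placeholders assumed to be in scope, mirroring the recipe change further down in this diff.

from torchtune.modules.peft import validate_missing_and_unexpected_for_lora

# Base weights are loaded non-strictly: the LoRA params are expected to be missing.
base_missing, base_unexpected = model.load_state_dict(
    base_model_state_dict, strict=False
)

# Adapter weights are only present when resuming from an intermediate checkpoint.
lora_missing, lora_unexpected = None, None
if lora_weights_state_dict is not None:
    lora_missing, lora_unexpected = model.load_state_dict(
        lora_weights_state_dict, strict=False
    )

# Raises if a required base (non-LoRA) key is missing, an expected LoRA key is
# missing, or any unexpected key was loaded.
validate_missing_and_unexpected_for_lora(
    lora_attn_modules=cfg_model.lora_attn_modules,
    apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
    apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
    base_missing=base_missing,
    base_unexpected=base_unexpected,
    lora_missing=lora_missing,
    lora_unexpected=lora_unexpected,
)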

peft.disable_adapter


3 changes: 0 additions & 3 deletions docs/source/api_ref_training.rst
@@ -50,12 +50,9 @@ Utilities for enabling and working with distributed training.
:toctree: generated/
:nosignatures:

FSDPPolicyType
init_distributed
is_distributed
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
A reviewer (Contributor) commented:

There's a reference to this in the QAT tutorial too.

lora_fsdp_wrap_policy
gather_cpu_state_dict

.. _ac_label:
3 changes: 1 addition & 2 deletions docs/source/tutorials/lora_finetune.rst
@@ -205,8 +205,7 @@ model without any wrappers or custom checkpoint conversion logic.

.. note::
Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in
the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g.
:func:`validate_state_dict_for_lora() <torchtune.modules.peft.validate_state_dict_for_lora>` or
the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via
:func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.

Once we've loaded the base model weights, we also want to set only LoRA parameters to trainable.
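
For context on that step, a minimal sketch of how only the LoRA parameters are set to trainable, assuming lora_model is a LoRA-enabled torchtune model with the base weights already loaded:

from torchtune.modules.peft import get_adapter_params, set_trainable_params

# Collect only the adapter (LoRA) parameters, then mark them as the sole
# trainable parameters; all base model weights stay frozen.
lora_params = get_adapter_params(lora_model)
set_trainable_params(lora_model, lora_params)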
5 changes: 0 additions & 5 deletions docs/source/tutorials/qat_finetune.rst
@@ -168,11 +168,6 @@ modifications accordingly:
fake_quant_after_n_steps: 1000
memory_efficient_fsdp_wrap: False

.. note::

QAT in torchtune is currently not compatible with `memory_efficient_fsdp_wrap <https://pytorch.org/torchtune/stable/generated/torchtune.utils.get_full_finetune_fsdp_wrap_policy.html#torchtune.utils.get_full_finetune_fsdp_wrap_policy>`_.
This is a known issue and will be fixed in a future torchtune version.

Empirically, we observed that disabling fake quantization for the first N steps
led to better results, presumably because doing so allows the weights to stabilize
before we start introducing quantization noise to the fine-tuning process.
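
To make the effect of fake_quant_after_n_steps concrete, here is an illustrative sketch of step-gated fake quantization. The toggle_fake_quant helper is hypothetical and stands in for whatever enable/disable hooks the chosen QAT quantizer exposes; the loop is not the recipe's actual implementation.

from torch import nn


def toggle_fake_quant(model: nn.Module, enabled: bool) -> None:
    # Hypothetical helper: flip a flag that QAT-prepared modules are assumed
    # to consult before applying fake quantization in their forward pass.
    for module in model.modules():
        if hasattr(module, "fake_quant_enabled"):
            module.fake_quant_enabled = enabled


def train(model, dataloader, optimizer, loss_fn, fake_quant_after_n_steps=1000):
    # Keep fake quantization off at first so the fine-tuned weights can stabilize.
    toggle_fake_quant(model, enabled=False)
    for step, batch in enumerate(dataloader):
        if step == fake_quant_after_n_steps:
            # Only now start injecting quantization noise into training.
            toggle_fake_quant(model, enabled=True)
        optimizer.zero_grad()
        loss = loss_fn(model(batch["tokens"]), batch["labels"])
        loss.backward()
        optimizer.step()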
14 changes: 0 additions & 14 deletions recipes/lora_dpo_single_device.py
@@ -27,7 +27,6 @@
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -271,19 +270,6 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
)

base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
195 changes: 62 additions & 133 deletions tests/torchtune/modules/peft/test_utils.py
@@ -21,7 +21,6 @@
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)

N_LAYERS = 3
@@ -261,9 +260,10 @@ def test_set_trainable_params(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys,
lora_state_dict_keys,
base_model_state_dict_keys,
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected
"""
),
@@ -272,188 +272,117 @@
["q_proj", "k_proj"],
False,
False,
["q_proj.lora_a.weight", "dummy_param.weight"],
["q_proj.lora_a.weight"],
[],
["dummy_param.weight"],
[],
"",
),
(
["v_proj"],
False,
False,
["param_a", "param_b"],
None,
["param_a", "param_b"],
"",
),
(["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
(
["output_proj"],
False,
True,
["output_proj.weight", "output_proj.lora_a.weight"],
["output_proj.lora_a.weight"],
[],
["output_proj.weight"],
[],
"",
),
(["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"),
(
["k_proj", "output_proj"],
["q_proj"],
False,
True,
["k_proj.lora_a.weight", "param_a"],
["k_proj.lora_a.weight", "param_a"],
False,
["param_a"],
[],
["param_a"],
"found in LoRA",
[],
"Missing non-LoRA",
),
(
["k_proj"],
False,
["k_proj", "output_proj"],
False,
["k_proj.lora_a.weight"],
True,
[],
[],
["k_proj.lora_a.weight"],
"found in base model",
[],
"Missing LoRA key",
),
(
["k_proj"],
False,
["q_proj", "k_proj"],
True,
False,
["k_proj.lora_a.weight"],
["k_proj.lora"],
[],
["q_proj.lora"],
[],
None,
"Missing LoRA",
),
(["q_proj"], False, False, [], ["a"], ["a"], "overlapping"),
(
["v_proj"],
False,
False,
["dummy_param.weight"],
["v_proj.lora_a.weight"],
["dummy_param.weight"],
"Extra",
),
(
["w1", "w2", "w3"],
["q_proj", "k_proj"],
True,
False,
["w1.lora_a.weight", "w2.weight", "q_proj.weight"],
["w1.lora_a.weight"],
["q_proj.weight"],
"Missing non-LoRA key",
["k_proj.lora"],
[],
["q_proj.magnitude"],
[],
"Missing LoRA",
),
(
["q_proj", "output"],
False,
["q_proj", "k_proj"],
True,
[
"q_proj.lora_a",
"output.weight",
"output.lora_a",
"output_proj.lora_b",
],
["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
["output.weight"],
"Missing non-LoRA key",
),
(
["q_proj", "v_proj"],
False,
False,
"lora_llama2_model_all_keys",
"lora_llama2_expected_adapter_keys",
"lora_llama2_expected_base_model_keys",
"",
["output_proj.lora"],
[],
["q_proj.lora"],
[],
"Missing non-LoRA",
),
(
["q_proj", "v_proj"],
False,
["q_proj", "k_proj"],
True,
False,
"dora_llama2_model_all_keys",
"dora_llama2_expected_adapter_keys",
"lora_llama2_expected_base_model_keys",
"",
),
],
)
def test_validate_lora_state_dict(
self,
request,
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys,
lora_state_dict_keys,
base_model_state_dict_keys,
expected,
):
if isinstance(full_model_state_dict_keys, str):
full_model_state_dict_keys = request.getfixturevalue(
full_model_state_dict_keys
)
if isinstance(lora_state_dict_keys, str):
lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
if isinstance(base_model_state_dict_keys, str):
base_model_state_dict_keys = request.getfixturevalue(
base_model_state_dict_keys
)
if expected:
with pytest.raises(AssertionError, match=expected):
validate_state_dict_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys=full_model_state_dict_keys,
lora_state_dict_keys=lora_state_dict_keys,
base_model_state_dict_keys=base_model_state_dict_keys,
)
else:
validate_state_dict_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys=full_model_state_dict_keys,
lora_state_dict_keys=lora_state_dict_keys,
base_model_state_dict_keys=base_model_state_dict_keys,
)

@pytest.mark.parametrize(
(
"""
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected
"""
),
[
(["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
(["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
(["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
(
["k_proj.lora"],
["output.weight"],
["q_proj.base_weight"],
[],
"loading base model",
),
(
["q_proj", "k_proj"],
True,
False,
["k_proj.lora"],
[],
["q_proj.base_weight"],
["output.weight"],
"loading adapter",
),
(["k_proj.lora"], [], ["q_proj.base_weight"], [], ""),
(
["q_proj", "k_proj"],
True,
False,
["k_proj.lora"],
[],
["q_proj.base_weight"],
[],
"",
),
],
)
def test_validate_missing_and_unexpected_for_lora(
self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
self,
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected,
):
lora_attn_modules = ["q_proj", "k_proj"]
apply_lora_to_mlp = True
apply_lora_to_output = False

if expected:
with pytest.raises(AssertionError, match=expected):
validate_missing_and_unexpected_for_lora(