Skip to content

Commit

Permalink
Fix Gemma 7B LoRA checkpoint save (#1169)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers authored Jul 12, 2024
1 parent f292b14 commit cc92fa0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchtune/models/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def tune_to_peft_adapter_weights(
num_heads: int = 32,
num_kv_heads: int = 32,
dim: int = 4096,
head_dim: int = None,
):
converted_state_dict = {}
full_mapping = {}
Expand All @@ -266,7 +267,8 @@ def tune_to_peft_adapter_weights(
}
)

head_dim = dim // num_heads
if head_dim is None:
head_dim = dim // num_heads

def _permute_lora_matrix(t, n_heads):
rank = t.shape[-1]
Expand Down
1 change: 1 addition & 0 deletions torchtune/utils/_checkpointing/_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ def save_checkpoint(
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
peft_output_path = Path.joinpath(
self._output_dir, "adapter_model"
Expand Down

0 comments on commit cc92fa0

Please sign in to comment.