
Commit b197368

cleanup debug code, v nice state dict hook workaround
1 parent 46f6dd0 commit b197368

2 files changed (+5, -5 lines)

torchtune/modules/model_fusion/_fusion.py

Lines changed: 5 additions & 4 deletions
@@ -232,10 +232,11 @@ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
         """Apply extra "embedding" prefix to the state_dict key to
         account for the FusionEmbedding wrapping.
         """
-        key = prefix + "weight"
-        new_key = prefix + "embedding.weight"
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
 
     def fusion_params(self) -> List[str]:
         """

torchtune/modules/peft/_utils.py

Lines changed: 0 additions & 1 deletion
@@ -259,7 +259,6 @@ def get_merged_lora_ckpt(
 
         # Otherwise it is just vanilla LoRA
         else:
-            print(f"module is {module}")
             state_dict[f"{module}.weight"] += (
                 (alpha / rank) * lora_b_weight @ lora_a_weight
             )
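
For context, the statement retained above is the standard vanilla LoRA merge: the adapted weight is W + (alpha / rank) * B @ A, so the adapter matrices can be folded into the base weight at checkpoint time. A minimal sketch with placeholder shapes (out_dim, in_dim, rank, alpha are made up for illustration, not values from this repo):

import torch

# Placeholder shapes for illustration only.
out_dim, in_dim, rank, alpha = 16, 32, 4, 8.0
base_weight = torch.randn(out_dim, in_dim)
lora_a_weight = torch.randn(rank, in_dim)   # A: projects the input down to rank
lora_b_weight = torch.zeros(out_dim, rank)  # B: projects back up (zero-initialized)

# Same update as state_dict[f"{module}.weight"] += (alpha / rank) * B @ A.
merged = base_weight + (alpha / rank) * lora_b_weight @ lora_a_weight
assert merged.shape == base_weight.shape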
