
Commit eca5265

Use new get_model_state_dict api for save_pretrained peft model (#629)
1 parent 48ba680 commit eca5265

File tree

3 files changed: +13 -3 lines changed

src/llama_recipes/model_checkpointing/__init__.py (+1)

@@ -4,6 +4,7 @@
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
+    save_peft_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

src/llama_recipes/model_checkpointing/checkpoint_handler.py (+10 -1)

@@ -26,6 +26,7 @@
 )


+from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 import torch.distributed._shard.checkpoint as dist_cp
 import torch.distributed as dist

@@ -264,4 +265,12 @@ def load_sharded_model_single_gpu(model,model_path):
     model.load_state_dict(state_dict["model"])

     print(f"Sharded state checkpoint loaded from {model_path}")
-    return model
+    return model
+
+def save_peft_checkpoint(model, model_path):
+    """save_pretrained peft model"""
+
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+    state_dict = get_model_state_dict(model, options=options)
+    model.save_pretrained(model_path, state_dict=state_dict)
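
For context, `get_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True)` gathers a full, CPU-resident state dict even when the model is sharded with FSDP, which is what lets `save_pretrained` write complete adapter weights. A minimal sketch of how the new helper might be used and the saved adapter reloaded afterwards; the model id, output path, and function name below are illustrative assumptions, not part of this commit:

    # Minimal sketch, assuming llama-recipes, transformers and peft are installed.
    from transformers import AutoModelForCausalLM
    from peft import PeftModel
    from llama_recipes.model_checkpointing import save_peft_checkpoint


    def save_and_reload_adapter(trained_peft_model, base_model_id, output_dir):
        # Works for both plain and FSDP-wrapped PEFT models: save_peft_checkpoint
        # materializes a full, CPU-offloaded state dict before calling
        # save_pretrained, so only usable adapter weights land in output_dir.
        save_peft_checkpoint(trained_peft_model, output_dir)

        # Reload: attach the saved adapter to a freshly loaded base model.
        base = AutoModelForCausalLM.from_pretrained(base_model_id)
        return PeftModel.from_pretrained(base, output_dir)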

src/llama_recipes/utils/train_utils.py (+2 -2)

@@ -20,7 +20,7 @@
 import json


-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available

@@ -235,7 +235,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)
+                    save_peft_checkpoint(model, train_config.output_dir)
                     if train_config.enable_fsdp:
                         if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
