Commit 56fe130

[hotfix] fix lora load (#6231)

* [hotfix] fix lora load
* [hotfix] fix hp load
* accelerate deepseek loading

1 parent f32861c, commit 56fe130

File tree

10 files changed: +146 -38 lines

applications/ColossalChat/examples/training_scripts/lora_finetune.py (+1 -1)

@@ -257,7 +257,7 @@ def is_master():
     )

     torch.set_default_dtype(torch.float)
-    booster.load_model(model, args.pretrained)
+    booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)

     coordinator.print_on_master(
         f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"

colossalai/booster/plugin/gemini_plugin.py (+7 -7)

@@ -85,11 +85,11 @@ def save_unsharded_model(
         if use_async:
             from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
             for k, v in state_dict.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                self.pinned_state_dicts[hash(model)][k].copy_(v)
+                state_dict[k] = self.pinned_state_dicts[hash(model)][k]
             writer = save(checkpoint, state_dict)
             self.async_writers.append(writer)
         else:

@@ -172,9 +172,9 @@ def save_sharded_model(
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = model.state_dict_shard(
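
The switch from id(model) to hash(model) as the cache key matters because unwrap() now returns a freshly constructed PeftUnwrapMixin on every call: id() would change per call and the pinned buffers would never be reused, while PeftUnwrapMixin.__hash__ (added in colossalai/interface/model.py below) forwards to the underlying base model and stays stable. A minimal, self-contained illustration with a simplified stand-in wrapper:

# Why key on hash(model) rather than id(model): a fresh wrapper is built per unwrap()
# call, so id() differs each time, while __hash__ delegates to the base model.
import torch.nn as nn

class UnwrapProxy:
    """Simplified stand-in for PeftUnwrapMixin, only for this illustration."""
    def __init__(self, base_model: nn.Module):
        self.base_model = base_model

    def __hash__(self):
        return hash(self.base_model)

base = nn.Linear(4, 4)
a, b = UnwrapProxy(base), UnwrapProxy(base)      # two unwrap() calls -> two wrappers

pinned_cache = {hash(a): {"weight": base.weight.detach().clone()}}  # pinned buffers in the real code

assert id(a) != id(b)               # an id()-keyed cache would miss and re-allocate
assert hash(b) in pinned_cache      # the hash()-keyed cache hits and is reused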

colossalai/booster/plugin/hybrid_parallel_plugin.py (+2 -1)

@@ -26,6 +26,7 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed

@@ -225,7 +226,7 @@ def unwrap(self, unwrap_peft: bool = True):
         if isinstance(model, DDP):
             model = model.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model

     def _force_wait_all_gather(self):
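
The unwrap() change is the core of the LoRA-load fix: peft's get_base_model() returns a module whose parameter names still carry the LoRA wrapping (base_layer., lora_A.<adapter>., ...), so they no longer line up with the keys of a plain pretrained checkpoint. PeftUnwrapMixin (added in colossalai/interface/model.py below) re-exposes the original names. A small illustration of the mismatch, assuming the peft package is installed; the Tiny module is a made-up example:

# Shows the key mismatch that PeftUnwrapMixin resolves (toy module, assumes peft).
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

peft_model = get_peft_model(Tiny(), LoraConfig(target_modules=["proj"]))
base = peft_model.get_base_model()

print([n for n, _ in Tiny().named_parameters()])
# ['proj.weight', 'proj.bias']
print([n for n, _ in base.named_parameters()])
# ['proj.base_layer.weight', 'proj.base_layer.bias',
#  'proj.lora_A.default.weight', 'proj.lora_B.default.weight']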

colossalai/booster/plugin/torch_ddp_plugin.py (+2 -1)

@@ -12,6 +12,7 @@
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device

@@ -201,7 +202,7 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
     def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
         model = self.module.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model

colossalai/booster/plugin/torch_fsdp_plugin.py (+7 -7)

@@ -103,11 +103,11 @@ def save_unsharded_model(
         if use_async:
             from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
             for k, v in full_model_state.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                self.pinned_state_dicts[hash(model)][k].copy_(v)
+                full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
             writer = save(checkpoint, full_model_state)
             self.async_writers.append(writer)
         else:

@@ -186,9 +186,9 @@ def save_sharded_model(
         state_dict = model.unwrap().state_dict()

         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = utils.shard_model_checkpoint(
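
The id-to-hash change repeats in every CheckpointIO that uses the same async-save pattern: device tensors are copied into a per-model cache of pinned (page-locked) host buffers, and a background writer flushes them to disk. The sketch below is only a stand-in for that pattern, assuming a CUDA device is available; the real code uses create_pinned_state_dict and the safetensors-based writers rather than a thread plus torch.save.

# Stand-in sketch of the pinned-buffer async-save pattern (not the real implementation).
import threading
import torch

pinned_state_dicts: dict = {}

def async_save(model: torch.nn.Module, path: str):
    state_dict = model.state_dict()
    key = hash(model)  # stable across re-wrapping, unlike id(model)
    if key not in pinned_state_dicts:
        # page-locked host buffers allow fast, non-blocking device-to-host copies
        pinned_state_dicts[key] = {
            k: torch.empty_like(v, device="cpu").pin_memory() for k, v in state_dict.items()
        }
    pinned = pinned_state_dicts[key]
    for k, v in state_dict.items():
        pinned[k].copy_(v, non_blocking=True)
    torch.cuda.synchronize()  # ensure the copies finished before writing
    writer = threading.Thread(target=torch.save, args=(pinned, path))
    writer.start()
    return writer  # caller keeps it, like self.async_writers, and joins before exit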

colossalai/checkpoint_io/general_checkpoint_io.py (+5 -5)

@@ -60,9 +60,9 @@ def save_unsharded_model(
         if use_async:
             from colossalai.utils.safetensors import move_and_save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
             self.async_writers.append(writer)
         else:
             # save the checkpoint

@@ -234,7 +234,7 @@ def save_sharded_model(
         index_file = CheckpointIndexFile(checkpoint_path)

         if use_async:
-            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
             total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint_path,

@@ -243,7 +243,7 @@ def save_sharded_model(
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
             )
-            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)
         else:
             # Save shards of optimizer states.

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (+11 -11)

@@ -249,9 +249,9 @@ def save_sharded_model(
         # Only devices with tp_rank == 0 are responsible for model saving.
         control_saving = self.tp_rank == 0 and self.sp_rank == 0
         if control_saving and use_async:
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = HybridParallelCheckpointIO._model_sharder(

@@ -789,11 +789,11 @@ def save_unsharded_model(
         if use_async:
             from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
             for name, param in state_dict.items():
-                self.pinned_state_dicts[id(model)][name].copy_(param)
-                state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                self.pinned_state_dicts[hash(model)][name].copy_(param)
+                state_dict[name] = self.pinned_state_dicts[hash(model)][name]
             writer = save(path=checkpoint, state_dict=state_dict)
             self.async_writers.append(writer)
         else:

@@ -811,11 +811,11 @@ def save_unsharded_model(
         if use_async:
             from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
             for name, param in complete_state_dict.items():
-                self.pinned_state_dicts[id(model)][name].copy_(param)
-                complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                self.pinned_state_dicts[hash(model)][name].copy_(param)
+                complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
             writer = save(path=checkpoint, state_dict=complete_state_dict)
             self.async_writers.append(writer)
         else:

colossalai/checkpoint_io/moe_checkpoint.py (+6 -3)

@@ -701,15 +701,18 @@ def pre_save_model(self, model: nn.Module) -> dict:
                     all_param = None
                 # gather param from every ep rank
                 # dist.all_gather(all_param, param, group=ep_group)
-                dist.gather(param, all_param, group=ep_group)
+                dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
                 if ep_rank == 0:
                     all_param = torch.cat(all_param, dim=0)
                     state_dict[name] = all_param.cpu()

         if self.pp_size > 1:
             if self.dp_rank == 0:
-                out = [None for _ in range(self.pp_size)]
-                dist.gather_object(state_dict, out, group=self.pp_group)
+                if self.pp_rank == 0:
+                    out = [None for _ in range(self.pp_size)]
+                else:
+                    out = None
+                dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
                 if self.pp_rank == 0:
                     new_state_dict = {}
                     for o in out:
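
The two gather fixes address the same pitfall: for torch.distributed.gather and gather_object, dst is a global rank even when a subgroup is passed, and the output list must be allocated only on that destination rank (other ranks pass None). A minimal, single-process sketch of the corrected pattern; the group here merely stands in for ep_group / pp_group:

# Single-process demo of the gather pattern used above (gloo backend, world size 1).
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29533")
dist.init_process_group("gloo", rank=0, world_size=1)

group = dist.new_group(ranks=[0])          # stands in for ep_group / pp_group
dst = dist.get_global_rank(group, 0)       # translate group rank 0 to a global rank

obj = {"rank": dist.get_rank()}
out = [None] * dist.get_world_size(group) if dist.get_rank() == dst else None
dist.gather_object(obj, out, dst=dst, group=group)   # non-dst ranks must pass None

if dist.get_rank() == dst:
    print(out)                             # [{'rank': 0}] in this one-process demo
dist.destroy_process_group()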

colossalai/checkpoint_io/utils.py (+6)

@@ -20,6 +20,7 @@
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

 from colossalai.accelerator import get_accelerator
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,

@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
         from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
     except ImportError:
         return
+    if isinstance(model, PeftUnwrapMixin):
+        model = model.base_model
     if not isinstance(model, PreTrainedModel):
         return

@@ -692,6 +695,9 @@ def load_state_dict_into_model(
         state_dict (dict): a dict containing parameters and
             persistent buffers.
     """
+    if isinstance(model, PeftUnwrapMixin):
+        state_dict = model.patch_state_dict(state_dict)
+        model = model.base_model
     if not isinstance(state_dict, Mapping):
         raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

colossalai/interface/model.py (+99 -2)

@@ -1,5 +1,102 @@
+import re
+from typing import Dict, Set
+
+import torch
 import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k

+        to_return = {renamed_dora_weights(k) for k in to_return}

+    to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def state_dict(self):
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)


 class ModelWrapper(nn.Module):

@@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True):
         else:
             model = self.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model

     def forward(self, *args, **kwargs):
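
With PeftUnwrapMixin in place, unwrap(unwrap_peft=True) hands the checkpoint code an object whose named_parameters(), state_dict(), and load_state_dict() all speak the original (non-peft) key names, so a vanilla pretrained checkpoint loads into a LoRA-wrapped model and saved checkpoints keep the expected layout. A hedged usage sketch with a made-up toy module, assuming peft is installed:

# Usage sketch of PeftUnwrapMixin (toy module; assumes peft is installed).
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from colossalai.interface.model import PeftUnwrapMixin

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

plain = Tiny()                                   # stands in for a pretrained model
peft_model = get_peft_model(Tiny(), LoraConfig(target_modules=["proj"]))

unwrapped = PeftUnwrapMixin(peft_model)
# state_dict() exposes 'proj.weight' / 'proj.bias' instead of 'proj.base_layer.*'
assert "proj.weight" in unwrapped.state_dict()

# a checkpoint saved from the plain model now loads back cleanly
# (strict=False because the LoRA adapter weights are not in the plain checkpoint)
unwrapped.load_state_dict(plain.state_dict(), strict=False)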

0 commit comments