Add ordered save to avoid OOM (#9347)
ForFishes authored Nov 1, 2024
1 parent 71dafa6 commit ba5c2ca
Showing 3 changed files with 36 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/trainer.md
@@ -725,6 +725,9 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
async_save, enable asynchronous saving of checkpoints to disk.
enable_all_options, enable all unified checkpoint optimization configs.
--ordered_save_group_size
The number of processes that save the checkpoint at the same time while ranks take turns in rounds. If set to 0, ordered (round-robin) checkpoint saving is disabled.
--skip_memory_metrics
Whether to skip memory profiler checks. (optional; defaults to True, i.e. skipped)
Whether or not to skip adding memory profiler reports
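For context, a minimal sketch of how the documented flag might be set from Python. This is not part of the commit: `output_dir` and the degree values are illustrative assumptions, and `tensor_parallel_degree` must be divisible by `ordered_save_group_size`.

```python
# Hedged sketch: enabling ordered save alongside tensor parallelism.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",   # assumed path
    tensor_parallel_degree=8,     # 8 model-parallel ranks
    ordered_save_group_size=4,    # 4 ranks write at a time -> 2 save rounds; 0 disables ordered save
)
```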
27 changes: 27 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -396,6 +396,13 @@ def _save_ckpt_func(state_dict, path, signal_path=None):

self._save_ckpt_func = _save_ckpt_func
self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load

if self.args.ordered_save_group_size > 0:
    logger.info(f"Using ordered save with group size {self.args.ordered_save_group_size}")
    assert not self.args.use_async_save, "Async save is not supported together with ordered save"
    assert (
        self.args.tensor_parallel_degree % self.args.ordered_save_group_size == 0
    ), "tensor_parallel_degree must be divisible by ordered_save_group_size"
    self._save_ckpt_func = self._ordered_save

if self.args.use_async_save:
    self._async_optimizer_saver = AsyncSaver()
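The two assertions above make ordered save mutually exclusive with `use_async_save` and require the model-parallel degree to be a whole multiple of the group size, so every save round has the same number of writers. A hypothetical stand-alone check that mirrors them (the helper name and its parameters are mine, not part of the commit):

```python
def check_ordered_save(tensor_parallel_degree: int, ordered_save_group_size: int, use_async_save: bool) -> None:
    """Illustrative mirror of the assertions guarding ordered save."""
    if ordered_save_group_size <= 0:
        return  # ordered save disabled, nothing to check
    if use_async_save:
        raise ValueError("ordered save cannot be combined with use_async_save")
    if tensor_parallel_degree % ordered_save_group_size != 0:
        raise ValueError("tensor_parallel_degree must be divisible by ordered_save_group_size")

check_ordered_save(tensor_parallel_degree=8, ordered_save_group_size=4, use_async_save=False)  # passes
```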

@@ -2369,6 +2376,26 @@ def _filter_moe_no_sync_optimizer_params(self):
filter_optimzier_state_dict[op_k] = op_v
return filter_optimzier_state_dict

def _ordered_save(self, state_dict, save_path):
    group_size = self.args.ordered_save_group_size
    hcg = fleet.get_hybrid_communicate_group()
    if hcg.get_sharding_parallel_world_size() > 1 or hcg.get_model_parallel_world_size() <= 1:
        return paddle.save(state_dict, save_path)

    mp_group = hcg.get_model_parallel_group()
    ranks = list(mp_group.ranks)
    n = len(ranks)

    group_num = (n + group_size - 1) // group_size
    groups = []
    for i in range(group_num):
        groups.append([ranks[j] for j in range(i, n, group_num)])

    for group in groups:
        if dist.get_rank() in group:
            paddle.save(state_dict, save_path)
        dist.barrier(mp_group)

def _save_checkpoint(self, model, metrics=None):
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
    self.runtime_timer.start("checkpoint saving time")
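`_ordered_save` above bounds how many model-parallel ranks write a checkpoint simultaneously: the ranks are split into `group_num` strided groups, and each group saves while the others wait at a barrier, presumably because peak host memory during saving scales with the number of concurrent writers (the OOM the commit title refers to). A stand-alone sketch of just the grouping step, using a hypothetical helper name and no Paddle dependency:

```python
def make_save_groups(ranks, group_size):
    """Split ranks into strided groups so at most group_size ranks save per round."""
    n = len(ranks)
    group_num = (n + group_size - 1) // group_size  # ceil(n / group_size) rounds
    # Group i holds ranks i, i + group_num, i + 2 * group_num, ...
    return [[ranks[j] for j in range(i, n, group_num)] for i in range(group_num)]

# 8 model-parallel ranks with group_size=4 -> two rounds of 4 writers each
print(make_save_groups(list(range(8)), group_size=4))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```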
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -827,6 +827,12 @@ class TrainingArguments:
    default=False,
    metadata={"help": "Whether to use async_save instead of paddle.save."},
)
ordered_save_group_size: int = field(
    default=0,
    metadata={
        "help": "Number of processes that save the checkpoint at the same time in each round of ordered save. If ordered_save_group_size=0, ordered save is disabled."
    },
)
skip_profile_timer: Optional[bool] = field(
    default=True,
    metadata={"help": "Enable the framework timer; timeline information will be output to logging and VisualDL."},
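Because the option is an ordinary `TrainingArguments` field, it can also be passed on the command line. A sketch assuming `PdArgumentParser` handles the new flag like the existing ones (the argument values are illustrative):

```python
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./checkpoints", "--tensor_parallel_degree", "8", "--ordered_save_group_size", "4"]
)
print(training_args.ordered_save_group_size)  # 4
```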
