[AutoParallel] merge ckpt for inference (#9688)
* add_readme

* add_unified_ckpt

* fix typo
xuxinyi389 authored Dec 26, 2024
1 parent f7f3957 commit 8c04a15
Showing 2 changed files with 121 additions and 0 deletions.
29 changes: 29 additions & 0 deletions llm/auto_parallel/llama/README.md
@@ -32,3 +32,32 @@ cd ../../../slm/model_zoo/gpt-3/external_ops/ && python3 setup.py install && cd -
Refer to the training script **run_pretrain_auto.sh** and enable `to_static=1` to run the dp2mp2pp2 parallel strategy on 8 GPUs.

You can refer to **run_pretrain_auto.sh** and modify the relevant parameters as needed for your training.

## 4. Inference
The inference pipeline is: dynamic graph inference -> dynamic-to-static model export -> static graph inference. Model parameters saved by auto-parallel pretraining can already be used for dynamic graph inference; for dynamic-to-static export and static graph inference, see the [LLaMA series documentation](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/llama.md).

Taking dynamic graph auto-parallel training (dp2mp2pp2) as an example:
- Merge the distributed checkpoint into single-card model parameters:

```python
import paddle
import paddle.distributed as dist

ckpt_path = '/path/for/dist_ckpt'
# offload=1 offloads the parameters to CPU to reduce GPU memory usage
merged_state_dict = dist.checkpoint.load_state_dict.load_merged_state_dict(ckpt_path, offload=1)
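
# Optional sanity check (illustrative): print a few parameter names and shapes
# to confirm the merge looks reasonable before saving.
for name in list(merged_state_dict.keys())[:5]:
    print(name, merged_state_dict[name].shape)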
paddle.save(merged_state_dict, 'model_state.pdparams')

```

The merged parameters above are in Paddle's native format. To convert them to the unified (safetensors) format, run:

```shell
python PaddleNLP/llm/auto_parallel/utils/convert_to_safetensors.py --input_path input_path [--output_dir output_dir] [--split_num split_num] [--offload]
```

Arguments:
- `--input_path`: path to the input single-card model parameters.
- `--output_dir`: optional, directory for the converted model parameters; defaults to `./tmp`.
- `--split_num`: optional, number of shards to split the output model parameters into; defaults to 1.
- `--offload`: optional flag; if set, offload the parameters to CPU before saving. Defaults to false.
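
For example, with `--split_num 2` the output directory contains two shards plus an index file. The shard names follow the conversion script below; the index filename comes from PaddleNLP's `SAFE_WEIGHTS_INDEX_NAME` (typically `model.safetensors.index.json`):

```
model-0001-of-0002.safetensors
model-0002-of-0002.safetensors
model.safetensors.index.json
```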

- Dynamic graph inference

  [Large model inference tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/inference.md)
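
  A minimal illustrative invocation of the dynamic graph predictor with the merged parameters; the script path and flags below are assumptions based on the tutorial above, so check it for the exact interface:

  ```shell
  # Illustrative only: the predictor path and flags may differ across PaddleNLP versions.
  python llm/predict/predictor.py \
      --model_name_or_path /path/to/merged_model \
      --dtype bfloat16 \
      --mode dynamic
  ```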
92 changes: 92 additions & 0 deletions llm/auto_parallel/utils/convert_to_safetensors.py
@@ -0,0 +1,92 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import paddle
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils.env import SAFE_WEIGHTS_INDEX_NAME


def convert_to_unified_ckpt(path: str, output_dir: str = "./tmp", split_num: int = 1, offload: bool = False):
"""
Convert a single card checkpoint to the unified format.
Args:
path (str): The path to the input checkpoint file.
        output_dir (str, optional): The directory where the converted files will be saved. Defaults to "./tmp".
        split_num (int, optional): The number of shards to split the weights into. Defaults to 1.
offload (bool, optional): Whether to offload the weights to CPU memory before saving them. Defaults to False.
"""

def get_sub_state_dict(sub_keys, state_dict, weight_filename, index_weight_file, total_size):
"""
Get the sub-state dict and update the index weight file and total size.
Args:
sub_keys (list): A list of keys that belong to this sub-state dict.
state_dict (dict): The original state dict.
weight_filename (str): The filename of the corresponding weight file.
index_weight_file (dict): The dictionary containing the mapping from keys to their corresponding weight filenames.
total_size (int): The total size of the model so far.
"""
sub_state_dict = {key: state_dict[key].numpy() for key in sub_keys}
for key in sub_keys:
index_weight_file[key] = weight_filename
total_size += state_dict[key].numel().item() * dtype_byte_size(state_dict[key].dtype)
return sub_state_dict, total_size

if offload:
paddle.set_device("cpu")
state_dict = paddle.load(path)
all_keys = list(state_dict.keys())
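    # Split the keys as evenly as possible across shards; the first `extra_keys`
    # shards each receive one extra key.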
split_size = len(all_keys) // split_num
extra_keys = len(all_keys) % split_num
index_weight_file = {}
total_size = 0

os.makedirs(output_dir, exist_ok=True)

index = 0
for rank in range(split_num):
current_size = split_size + (1 if rank < extra_keys else 0)
sub_keys = all_keys[index : index + current_size]
index += current_size
weight_filename = f"model-{rank+1:04d}-of-{split_num:04d}.safetensors"
sub_state_dict, total_size = get_sub_state_dict(
sub_keys, state_dict, weight_filename, index_weight_file, total_size
)
safe_save_file(sub_state_dict, os.path.join(output_dir, weight_filename))
with open(os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME), "w") as f:
json.dump({"metadata": {"total_size": total_size}, "weight_map": index_weight_file}, f, indent=4)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, required=True, help="The path to the input checkpoint file.")
parser.add_argument(
"--output_dir", type=str, default="./tmp", help="The directory where the converted files will be saved."
)
parser.add_argument(
"--split_num", type=int, default=1, help="The number of shards to split the weights into output_dir."
)
    parser.add_argument(
        "--offload", action="store_true", help="Whether to offload the weights to CPU memory before saving them."
    )
args = parser.parse_args()
convert_to_unified_ckpt(args.input_path, args.output_dir, args.split_num, args.offload)
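
To verify the converted output, the shards can be loaded back with `safetensors` and checked against the index file. A minimal sketch, assuming the default `--output_dir` and the standard index filename (`model.safetensors.index.json`):

```python
import json
import os

from safetensors.numpy import load_file

output_dir = "./tmp"  # placeholder: the --output_dir used above

# The index file maps each weight name to the shard file that stores it.
with open(os.path.join(output_dir, "model.safetensors.index.json")) as f:
    index = json.load(f)

# Load every shard and confirm all indexed weights are present.
weights = {}
for shard in sorted(set(index["weight_map"].values())):
    weights.update(load_file(os.path.join(output_dir, shard)))
assert set(weights) == set(index["weight_map"]), "shard contents do not match the index"
print(f"loaded {len(weights)} tensors; metadata total_size = {index['metadata']['total_size']}")
```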
