Skip to content

Commit

Permalink
[LLM] add llama1-13b pretrain
Browse files Browse the repository at this point in the history
[LLM] llama1-7b pretrain with callback
  • Loading branch information
LaiXinyi823 committed Sep 12, 2023
1 parent 45c4220 commit 6f1da92
Show file tree
Hide file tree
Showing 24 changed files with 574 additions and 616 deletions.
1 change: 0 additions & 1 deletion training/benchmarks/driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .base import Driver
from .callback_paddle import PaddleCallback
from .event import Event
from .log_event import LogEventManager
92 changes: 0 additions & 92 deletions training/benchmarks/driver/callback_paddle.py

This file was deleted.

174 changes: 92 additions & 82 deletions training/benchmarks/driver/dist_paddle.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,24 @@
import os
from contextlib import contextmanager
import random
import numpy as np

import paddle
import paddle.distributed as dist

from paddlenlp.trainer import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from paddlenlp.trainer.trainer_utils import IntervalStrategy

from .base import Driver
from .event import Event
from typing import Dict

def barrier():
if dist.is_initialized():
dist.barrier()

def set_seed(args):
if args.device == "cpu":
idx = 0
else:
idx = paddle.distributed.get_rank()
random.seed(args.seed + idx)
np.random.seed(args.seed + idx)
paddle.seed(args.seed + idx)


def get_rank(default=0):
"""
Gets distributed rank or returns zero if distributed is not initialized.
"""
if dist.is_initialized():
rank = dist.get_rank()
else:
rank = default
return rank


def get_world_size():
"""
Gets total number of distributed workers or returns one if distributed is
not initialized.
"""
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
return world_size


def main_proc_print(*args, **kwargs):
if is_main_process():
print(*args, **kwargs)


def init_dist_training_env(config):
if dist.get_world_size() <= 1:
config.device = paddle.device.get_device()
config.world_size = get_world_size()
else:
dist.init_parallel_env()
config.device = paddle.device.get_device()
config.world_size = get_world_size()
print("------------------------")
print("device numbers:", config.world_size)
print("the processing uses", config.device)
return


def global_batch_size(config):

return config.per_device_train_batch_size * config.world_size


@contextmanager
def sync_workers():
"""
Yields distributed rank and synchronizes all workers on exit.
"""
rank = get_rank()
yield rank
barrier()


def is_main_process():
if dist.is_initialized():
if "PADDLE_TRAINER_ID" in os.environ:
Expand All @@ -86,15 +28,83 @@ def is_main_process():

return True


def format_step(step):
if isinstance(step, str):
return step
s = ""
if len(step) > 0:
s += "Training Epoch: {} ".format(step[0])
if len(step) > 1:
s += "Training Iteration: {} ".format(step[1])
if len(step) > 2:
s += "Validation Iteration: {} ".format(step[2])
return s
class PaddleCallback(TrainerCallback):
def __init__(self, driver: Driver):
self.driver = driver

def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerState,
**kwargs
):
self.driver.event(Event.INIT_END)

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
self.driver.event(Event.TRAIN_START)

def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
self.driver.event(Event.TRAIN_END)

def on_epoch_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
self.driver.event(Event.EPOCH_BEGIN, epoch=state.epoch)

def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
self.driver.event(Event.EPOCH_END, epoch=state.epoch)

def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
self.driver.event(Event.STEP_BEGIN, step=state.global_step + 1)

def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs
):
logs = kwargs["metrics"]
logs["global_step"] = state.global_step
self.driver.event(Event.EVALUATE, result=logs)

def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs
):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
self.driver.logger.log(Event.STEP_END, message=logs)
40 changes: 40 additions & 0 deletions training/benchmarks/llama1_13B/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
### 模型信息
#### 模型介绍
We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions
of tokens, and show that it is possible to train
state-of-the-art models using publicly available datasets exclusively, without resorting
to proprietary and inaccessible datasets. In
particular, LLaMA-13B outperforms GPT-3
(175B) on most benchmarks, and LLaMA65B is competitive with the best models,
Chinchilla-70B and PaLM-540B. We release
all our models to the research community1
.

Please refer to this paper for a detailed description of LLaMA1:
[LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)

#### 模型代码来源
Paddle case代码来源:
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/llama licensed under the Apache License, Version 2.0.


#### 数据集
##### 测试数据集下载地址
测试数据集中提供了处理好的100k条doc的训练样本:
```
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
```

##### 预处理
> 无需预处理
#### 模型checkpoint
Paddle通过`model_name_or_path = "facebook/llama-13b"`自动加载 llama1-13b 模型参数。参数数:13B
Paddle case的 LLaMA 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)

### 框架与芯片支持情况
| | Pytorch |Paddle|TensorFlow2|
| ---- | ---- | ---- | ---- |
| Nvidia GPU |N/A ||N/A|
| 天数智芯 |N/A | N/A |N/A|
1 change: 1 addition & 0 deletions training/benchmarks/llama1_13B/paddle
40 changes: 40 additions & 0 deletions training/benchmarks/llama1_7B/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
### 模型信息
#### 模型介绍
We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions
of tokens, and show that it is possible to train
state-of-the-art models using publicly available datasets exclusively, without resorting
to proprietary and inaccessible datasets. In
particular, LLaMA-13B outperforms GPT-3
(175B) on most benchmarks, and LLaMA65B is competitive with the best models,
Chinchilla-70B and PaLM-540B. We release
all our models to the research community1
.

Please refer to this paper for a detailed description of LLaMA1:
[LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)

#### 模型代码来源
Paddle case代码来源:
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/llama licensed under the Apache License, Version 2.0.


#### 数据集
##### 测试数据集下载地址
测试数据集中提供了处理好的100k条doc的训练样本:
```
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
```

##### 预处理
> 无需预处理
#### 模型checkpoint
Paddle通过`model_name_or_path = "facebook/llama-7b"`自动下载并加载 llama1-7b 的模型参数。参数数:7B。
Paddle case的 LLaMA 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)

### 框架与芯片支持情况
| | Pytorch |Paddle|TensorFlow2|
| ---- | ---- | ---- | ---- |
| Nvidia GPU |N/A ||N/A|
| 天数智芯 |N/A | N/A |N/A|
Loading

0 comments on commit 6f1da92

Please sign in to comment.