add paddle Bert kunlunxin case #172

Merged
merged 9 commits on Aug 1, 2023
30 changes: 17 additions & 13 deletions training/benchmarks/bert/paddle/schedulers/factory.py
@@ -2,27 +2,31 @@

from .linear_warmup_poly_scheduler import LinearWarmupPolyDecayScheduler
from .linear_warmup_scheduler import LinearWarmUpScheduler
from paddlenlp.transformers import PolyDecayWithWarmup


def create_scheduler(optimizer, scheduler="poly"):
if config.warmup_proportion == 0:
warmup_steps = config.warmup_steps
warmup_start = config.start_warmup_step
else:
if config.warmup_steps == 0:
warmup_steps = int(config.max_steps * config.warmup_proportion)
warmup_start = 0

else:
warmup_steps = config.warmup_steps
warmup_start = config.start_warmup_step
if scheduler == "linear":
return LinearWarmUpScheduler(optimizer, warmup_steps, config.max_steps)

if scheduler == "poly":
# return LinearWarmupPolyDecayScheduler(optimizer, start_warmup_steps=warmup_start,
# warmup_steps=warmup_steps,
# total_steps=config.max_steps, end_learning_rate=0.0, degree=1.0)
return PolyDecayWithWarmup(learning_rate=config.learning_rate,
warmup=warmup_steps,
total_steps=config.max_steps,
lr_end=0.0,
power=1.0)
return LinearWarmupPolyDecayScheduler(
startup_warmup_steps=warmup_start,
warmup_steps=warmup_steps,
total_steps=config.max_steps,
base_lr=config.learning_rate,
end_lr=0.0,
degree=1.0)

# return PolyDecayWithWarmup(learning_rate=config.learning_rate,
# warmup=warmup_steps,
# total_steps=config.max_steps,
# lr_end=0.0,
# power=1.0)
raise ValueError(f"Not found scheduler {scheduler}.")
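
For orientation, here is a minimal usage sketch (not part of this PR) of how a Paddle-style LR scheduler such as the one returned by `create_scheduler` is typically consumed: it is passed to the optimizer as its `learning_rate` and stepped once per parameter update. The toy model, the AdamW choice, the import path, and the hyperparameter values are illustrative assumptions, not the benchmark's defaults.

```
import paddle
# import path assumed from this diff's directory layout
from schedulers.linear_warmup_poly_scheduler import LinearWarmupPolyDecayScheduler

model = paddle.nn.Linear(16, 2)  # toy stand-in model
lr_scheduler = LinearWarmupPolyDecayScheduler(
    startup_warmup_steps=0, warmup_steps=100, total_steps=1000,
    base_lr=4e-4, end_lr=0.0, degree=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                   parameters=model.parameters())

x = paddle.randn([8, 16])
for _ in range(3):
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    lr_scheduler.step()  # advance warmup / poly decay by one step
```
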
training/benchmarks/bert/paddle/schedulers/linear_warmup_poly_scheduler.py
@@ -1,48 +1,35 @@
import torch
from .base import LRScheduler
import sys
from paddle.optimizer.lr import LRScheduler


class LinearWarmupPolyDecayScheduler(LRScheduler):
"""
Applies a warm up period to the learning rate.
"""

def __init__(self,
optimizer,
start_warmup_steps,
startup_warmup_steps,
warmup_steps,
total_steps,
end_learning_rate=0.0,
base_lr,
end_lr=0.0,
degree=1.0,
last_epoch=-1):
self.num_warmup_updates = warmup_steps
self.start_warmup_steps = start_warmup_steps
self.startup_warmup_steps = startup_warmup_steps
self.offset_step = int(startup_warmup_steps == 0)
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.end_learning_rate = end_learning_rate
self.base_lr = base_lr
self.end_lr = end_lr
self.degree = degree
super(LinearWarmupPolyDecayScheduler,
self).__init__(optimizer, last_epoch)

if self.last_epoch <= 0:
self.last_epoch = 0

def step(self, epoch=None):
param_group = self.optimizer.param_groups[0]
if 'step' in param_group:
self.last_epoch = param_group['step'] + 1
else:
self.last_epoch += 1

for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
self).__init__(learning_rate=base_lr, last_epoch=last_epoch)

def get_lr(self):
mod_step = self.last_epoch - self.start_warmup_steps
if mod_step < self.num_warmup_updates:
progress = mod_step / self.num_warmup_updates
return [(base_lr * progress) for base_lr in self.base_lrs]
step = self.last_epoch + 1
mod_step = step - self.offset_step - self.startup_warmup_steps
if mod_step < self.warmup_steps:
p = mod_step / (self.warmup_steps + 1e-6)
lr = self.base_lr * p
else:
progress = min(self.last_epoch / self.total_steps, 1.0)
return [(base_lr - self.end_learning_rate) *
(1 - progress)**self.degree + self.end_learning_rate
for base_lr in self.base_lrs]
p = min(1, (step - self.offset_step) / self.total_steps)
lr = (self.base_lr - self.end_lr) * (1 -
p)**self.degree + self.end_lr
return lr
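
To make the schedule concrete, here is a small numeric sketch (an assumption, not code from the PR) of what `get_lr` computes, ignoring the `offset_step`/`startup_warmup_steps` bookkeeping: linear warmup to `base_lr` over `warmup_steps`, then polynomial decay toward `end_lr`.

```
def poly_warmup_lr(step, warmup_steps, total_steps, base_lr, end_lr=0.0, degree=1.0):
    if step < warmup_steps:
        return base_lr * step / (warmup_steps + 1e-6)       # linear warmup
    p = min(1.0, step / total_steps)                        # decay progress in [0, 1]
    return (base_lr - end_lr) * (1 - p) ** degree + end_lr  # polynomial decay

# With base_lr=4e-4, warmup_steps=100, total_steps=1000, degree=1.0:
print(poly_warmup_lr(50, 100, 1000, 4e-4))   # ~2.0e-4, halfway through warmup
print(poly_warmup_lr(550, 100, 1000, 4e-4))  # 1.8e-4, 45% of the decay remaining
```
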
7 changes: 2 additions & 5 deletions training/benchmarks/bert/paddle/train/driver/distributed.py
@@ -92,13 +92,10 @@ def setup_seeds(master_seed, epochs, device)

def barrier():
"""
Works as a temporary distributed barrier, currently pytorch
doesn't implement barrier for NCCL backend.
Calls all_reduce on dummy tensor and synchronizes with GPU.
Calls dist.barrier.
"""
if dist.is_initialized():
dist.all_reduce(paddle.to_tensor(1))
paddle.device.cuda.synchronize()
dist.barrier()


def get_rank(default=0):
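
A minimal sketch (assumption, not from this PR) of the kind of synchronization this `dist.barrier()` provides under `paddle.distributed`, e.g. letting rank 0 finish one-time work before the other ranks continue; it would be launched with `paddle.distributed.launch`.

```
import paddle.distributed as dist

dist.init_parallel_env()      # set up the collective backend (NCCL, BKCL, ...)

if dist.get_rank() == 0:
    pass                      # e.g. download or preprocess data exactly once
dist.barrier()                # every rank waits here until all ranks arrive
print(f"rank {dist.get_rank()} passed the barrier")
```
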
@@ -55,6 +55,8 @@ def create_grad_scaler():

def backward(step: int, loss, optimizer, **kwarg):
loss.backward()
optimizer.step()
optimizer.clear_grad()
need_update = step % config.gradient_accumulation_steps == 0
if need_update:
optimizer.step()
optimizer.clear_grad()
return
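
A self-contained sketch (assumption, not the benchmark's trainer) of the gradient-accumulation pattern this change introduces: gradients from several micro-batches accumulate across `backward()` calls, and the optimizer only updates and clears grads on every `gradient_accumulation_steps`-th step.

```
import paddle

model = paddle.nn.Linear(16, 2)  # toy stand-in model
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
gradient_accumulation_steps = 4  # illustrative value

for step in range(1, 17):        # step counting starts at 1
    x = paddle.randn([8, 16])
    loss = model(x).mean()
    loss.backward()              # grads accumulate in place across micro-batches
    if step % gradient_accumulation_steps == 0:
        optimizer.step()         # one parameter update per 4 micro-batches
        optimizer.clear_grad()
```
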
4 changes: 4 additions & 0 deletions training/benchmarks/bert/paddle/train/evaluator.py
@@ -64,6 +64,10 @@ def evaluate(self, trainer):
trainer.model.train()

if dist.is_initialized():
total_eval_mlm_acc = paddle.cast(total_eval_mlm_acc, 'float32')
total_eval_loss = paddle.cast(total_eval_loss, 'float32')
total_masked = paddle.cast(total_masked, 'float32')

dist.all_reduce(total_eval_mlm_acc, op=dist.ReduceOp.SUM)
dist.all_reduce(total_eval_loss, op=dist.ReduceOp.SUM)
dist.all_reduce(total_masked, op=dist.ReduceOp.SUM)
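
A minimal sketch (assumption) of the cast-then-all_reduce pattern added here. Casting the metrics to float32 before `dist.all_reduce` keeps the collectives on a single, widely supported dtype; that this is required by the XPU/BKCL backend is a plausible motivation but is not stated in the PR.

```
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# per-rank counters with mixed dtypes (illustrative values)
total_masked = paddle.to_tensor(1024, dtype='int64')
total_eval_loss = paddle.to_tensor(2.31, dtype='float32')

# cast to float32 so every collective runs on one supported dtype
total_masked = paddle.cast(total_masked, 'float32')
dist.all_reduce(total_masked, op=dist.ReduceOp.SUM)
dist.all_reduce(total_eval_loss, op=dist.ReduceOp.SUM)

avg_loss = total_eval_loss / dist.get_world_size()
```
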
73 changes: 73 additions & 0 deletions training/kunlunxin/bert-paddle/README.md
@@ -0,0 +1,73 @@

### Model Checkpoint Download
[Model Checkpoint Download](../../benchmarks/bert/README.md#模型checkpoint下载)


### Test Dataset Download
[Test Dataset Download](../../benchmarks/bert/README.md#测试数据集下载)


### Paddle Version Run Guide

● Bash environment variables:
```
export FLAGS_sync_nccl_allreduce=0
export FLAGS_fraction_of_gpu_memory_to_use=0.99
export FLAGS_call_stack_level=2
export FLAGS_use_fast_math=0
export FLAGS_enable_nvtx=1
export BKCL_CCIX_RING=1
export XPU_PADDLE_L3_SIZE=41943040
export XPU_PADDLE_FC_TRANS_A=1
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # indices of available XPUs
```

● Run script:

From within this directory:

```
python -u -m paddle.distributed.launch --xpus=${XPU_VISIBLE_DEVICES} run_pretraining.py \
--data_dir data_path \
--extern_config_dir config_path \
--extern_config_file config_file.py

```


example:
```
python -u -m paddle.distributed.launch --xpus=${XPU_VISIBLE_DEVICES} run_pretraining.py \
--data_dir /bert-data/train \
--extern_config_dir /home/FlagPerf/training/kunlunxin/bert-paddle/config \
--extern_config_file config_R300x1x8.py
```


### Kunlunxin XPU Configuration and Run Information Reference
#### Environment Configuration
- ##### Hardware Environment
  - Machine model: Kunlunxin AI accelerator group R480-X8
  - Accelerator card model: Kunlunxin AI accelerator card R300
  - Multi-node network type and bandwidth: InfiniBand, 200 Gb/s

- ##### Software Environment
  - OS version: Ubuntu 20.04
  - OS kernel version: 5.4.0-26-generic
  - Accelerator driver version: 4.0.25
  - Docker image and version: registry.baidubce.com/paddlepaddle/paddle:2.3.2
  - Training framework version: paddlepaddle+f6161d1
  - Dependency software version: pytorch-1.8.1



### Run Results
| Training Resources | Config File | Run Time (s) | Target Accuracy | Converged Accuracy | Steps | Performance (samples/s) |
| ------------------ | --------------- | ------------ | --------------- | ------------------ | ----- | ----------------------- |
| Single node, 8 cards | config_A100x1x8 | | 0.67 | 0.6709 | 11720 | |

### License

This project is licensed under the Apache 2.0 license.

Parts of this project's code are based on the MLCommons implementation at https://github.com/mlcommons/training_results_v1.0/tree/master/NVIDIA/benchmarks/.
19 changes: 19 additions & 0 deletions training/kunlunxin/bert-paddle/config/config_R300x1x8.py
@@ -0,0 +1,19 @@
dist_backend = "xccl"

target_mlm_accuracy = 0.67
gradient_accumulation_steps = 7
max_steps = 50000
start_warmup_step = 0
warmup_proportion = 0
warmup_steps = 0

learning_rate = 4e-4
weight_decay_rate = 0.01
opt_lamb_beta_1 = 0.9
opt_lamb_beta_2 = 0.999
train_batch_size = 8
eval_batch_size = train_batch_size
max_samples_termination = 4500000
cache_eval_data = False

seed = 9031
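
For reference, a small sketch (not part of the config) of how these values combine into the effective global batch size, assuming the single-node, 8-card setup described in the README above.

```
train_batch_size = 8             # per-device micro-batch size (this config)
gradient_accumulation_steps = 7  # this config
num_devices = 8                  # single node, 8 x R300 (from the README)

global_batch_size = train_batch_size * gradient_accumulation_steps * num_devices
print(global_batch_size)         # 448 samples per optimizer update
```
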
9 changes: 9 additions & 0 deletions training/kunlunxin/bert-paddle/config/requirements.txt
@@ -0,0 +1,9 @@
h5py==3.7.0
six==1.16.0
absl-py==1.2.0
paddle-bfloat==0.1.7
paddle2onnx==1.0.0
paddlefsl==1.1.0
paddlenlp==2.4.0
astor==0.8.1
torch==1.8.1