Commit 852b52c

[auto parallel] add ut for parallel api (#69033)

FeixLiu authored Oct 30, 2024
1 parent cead7f5 commit 852b52c
Showing 5 changed files with 611 additions and 2 deletions.
8 changes: 7 additions & 1 deletion test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -118,7 +118,6 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_semi_auto_llama_acc_align
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=HYBRID")
endif()

if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_llama_save_load
@@ -130,3 +129,10 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_semi_auto_llama_save_load
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_parallel_api_with_llama MODULES test_parallel_api_with_llama ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_parallel_api_with_llama
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=HYBRID")
endif()
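
For orientation only (not part of the diff): the driver script added below reads its parallel degrees from environment variables (dp, mp, pp, acc_step, plus optional amp/use_sp/use_lazy_init switches), which the registered ctest module presumably supplies. A minimal, hypothetical manual launch on two GPUs might look like this sketch:

    # Hypothetical manual launch of parallel_api.py on two GPUs; the committed
    # ctest entry (test_parallel_api_with_llama) normally sets these itself.
    import os
    import subprocess

    env = dict(os.environ, dp="2", mp="1", pp="1", acc_step="1")
    subprocess.check_call(
        [
            "python", "-m", "paddle.distributed.launch",
            "--devices", "0,1", "parallel_api.py",
        ],
        env=env,
    )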
299 changes: 299 additions & 0 deletions test/auto_parallel/hybrid_strategy/parallel_api.py
@@ -0,0 +1,299 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from functools import reduce

import numpy as np
from single_llama_model import LlamaForCausalLM, LlamaPretrainingCriterion

import paddle
import paddle.distributed as dist
from paddle import LazyGuard
from paddle.io import BatchSampler, DataLoader, Dataset


def is_pp_enable():
global_mesh = dist.auto_parallel.get_mesh()
return "pp" in global_mesh.dim_names


def get_mesh(pp_idx=None):
global_mesh = dist.auto_parallel.get_mesh()
assert global_mesh is not None, "global_mesh is not initialized!"
if pp_idx is None:
return global_mesh
if is_pp_enable():
mesh = global_mesh.get_mesh_with_dim("pp")[pp_idx]
return mesh
else:
return global_mesh
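

# Example: with a (dp=2, pp=2, mp=2) global mesh, get_mesh(0) returns the
# 2 x 2 ["dp", "mp"] sub-mesh of the first pipeline stage and get_mesh(-1)
# that of the last stage; without a "pp" axis the global mesh is returned
# unchanged.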


class Config:
vocab_size = 32000
hidden_size = 4096
intermediate_size = 11008
seq_length = 2048
num_hidden_layers = 2
num_attention_heads = 32
rms_norm_eps = 1e-6
sequence_parallel = False
use_lazy_init = False


class RandomDataset(Dataset):
def __init__(self, seq_len, num_samples=100):
super().__init__()
self.seq_len = seq_len
self.num_samples = num_samples

def __getitem__(self, index):
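        # uniform(0, 1) cast to int64 gives all-zero token ids and labels in
        # [0, 9]; this test only needs tensors with valid shapes and dtypes.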
input = np.random.uniform(size=[self.seq_len]).astype("int64")
label = (np.random.uniform(size=[self.seq_len]) * 10).astype("int64")
return input, label

def __len__(self):
return self.num_samples


def create_optimizer(model, lr_scheduler):
decay_parameters = [
p.name
for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]

def apply_decay_param_fun(x):
return x in decay_parameters

# test global_clip in auto_parallel
if os.getenv("use_param_group") == "true":
param_group = {}
param_group["params"] = list(model.parameters())
param_group["weight_decay"] = 0.01
param_group["grad_clip"] = paddle.nn.ClipGradByGlobalNorm(1.0)
optimizer = paddle.optimizer.adamw.AdamW(
learning_rate=lr_scheduler,
apply_decay_param_fun=apply_decay_param_fun,
parameters=[param_group],
)
else:
optimizer = paddle.optimizer.adamw.AdamW(
learning_rate=lr_scheduler,
apply_decay_param_fun=apply_decay_param_fun,
parameters=model.parameters(),
weight_decay=0.01,
grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
)
return optimizer


class TestParallelAPI:
def __init__(self):
self.config = Config()
self.dp = int(os.getenv("dp"))
self.mp = int(os.getenv("mp"))
self.pp = int(os.getenv("pp"))
if os.getenv("use_sp") == "true":
self.config.sequence_parallel = True
if os.getenv("use_lazy_init") == "true":
self.config.use_lazy_init = True
self.gradient_accumulation_steps = int(os.getenv("acc_step"))

self.amp = False
self.amp_dtype = "float16"
self.amp_level = "O1"
self.amp_master_grad = False
if os.getenv("amp") == "true":
self.amp = True
if os.getenv("amp_dtype") in ["float16", "bfloat16"]:
self.amp_dtype = os.getenv("amp_dtype")
if os.getenv("amp_level") in ["O0", "O1", "O2"]:
self.amp_level = os.getenv("amp_level")
if os.getenv("amp_master_grad") == "true":
self.amp_master_grad = True

self.init_dist_env()

def init_dist_env(self):
        # mesh axes follow the dp/pp/mp degrees taken from the environment
        mesh_dims = [("dp", self.dp), ("pp", self.pp), ("mp", self.mp)]
if not mesh_dims:
mesh_dims = [("dp", 1)]
dim_names = [mesh_dim[0] for mesh_dim in mesh_dims]
mesh_shape = [mesh_dim[1] for mesh_dim in mesh_dims]
mesh_arr = np.arange(
0, reduce(lambda x, y: x * y, mesh_shape, 1)
).reshape(mesh_shape)
global_mesh = dist.ProcessMesh(mesh_arr, dim_names)
dist.auto_parallel.set_mesh(global_mesh)

def parallel_model(self, layer):
# call parallel api here
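        # The base implementation is a deliberate no-op: derived tests plug in
        # their own parallelization plan here. A hypothetical override (not
        # part of this commit, and assuming the model exposes an
        # lm_head.weight parameter) could shard that weight column-wise on the
        # last pipeline stage's ["dp", "mp"] sub-mesh:
        #
        #   layer.lm_head.weight = dist.shard_tensor(
        #       layer.lm_head.weight,
        #       get_mesh(-1),
        #       [dist.Replicate(), dist.Shard(1)],
        #   )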
return layer

def run_llama(self, to_static=0, only_build=False):
if self.config.use_lazy_init:
with LazyGuard():
model = LlamaForCausalLM(self.config)
else:
model = LlamaForCausalLM(self.config)
criterion = LlamaPretrainingCriterion(self.config)

model = self.parallel_model(model)
criterion = self.parallel_model(criterion)
if self.config.use_lazy_init:
for param in model.parameters():
assert not param._is_initialized()
param.initialize()

if only_build:
return

lr_scheduler = paddle.optimizer.lr.LinearWarmup(
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
)
optimizer = create_optimizer(model, lr_scheduler)
if self.amp and not to_static:
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level=self.amp_level,
dtype=self.amp_dtype,
master_grad=self.amp_master_grad,
)
optimizer = dist.shard_optimizer(optimizer)

train_dataset = RandomDataset(self.config.seq_length)
train_sampler = BatchSampler(
train_dataset,
batch_size=2,
shuffle=True,
drop_last=True,
)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=0,
)

if self.pp == 1:
meshes = [get_mesh(0)]
elif self.pp > 1:
meshes = [get_mesh(0), get_mesh(-1)]
else:
            raise ValueError("pp should be greater than or equal to 1")

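        # shard_dataloader places the first item of each sample (input_ids) on
        # meshes[0] and the second (labels) on meshes[-1], matching what the
        # first and last pipeline stages consume; shard_dims="dp" splits the
        # batch along the "dp" mesh axis.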
dist_loader = dist.shard_dataloader(
dataloader=train_dataloader,
meshes=meshes,
shard_dims="dp",
)

global_step = 1
tr_loss = float(0)

if not to_static:
model.train()
scaler = None
if self.amp and self.amp_dtype == "float16":
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = dist.shard_scaler(scaler)

for epoch_idx in range(1):
for step, inputs in enumerate(dist_loader()):
input_ids, labels = inputs
custom_black_list = [
"reduce_sum",
"c_softmax_with_cross_entropy",
]
custom_white_list = []
if self.amp_level == "O2":
custom_white_list.extend(
["lookup_table", "lookup_table_v2"]
)
with paddle.amp.auto_cast(
self.amp,
custom_black_list=set(custom_black_list),
custom_white_list=set(custom_white_list),
level=self.amp_level,
dtype=self.amp_dtype,
):
logits = model(input_ids)
tr_loss_step = criterion(logits, labels)

if self.gradient_accumulation_steps > 1:
tr_loss_step /= self.gradient_accumulation_steps
if scaler is not None:
scaler.scale(tr_loss_step).backward()
else:
tr_loss_step.backward()
tr_loss += tr_loss_step

if global_step % self.gradient_accumulation_steps == 0:
logging.info(
f"step: {global_step // self.gradient_accumulation_steps} loss: {tr_loss.numpy()}"
)
if scaler is not None:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.clear_grad()
lr_scheduler.step()
tr_loss = 0

global_step += 1
if global_step // self.gradient_accumulation_steps >= 10:
break
else:
strategy = dist.Strategy()
if self.gradient_accumulation_steps > 1:
strategy.pipeline.accumulate_steps = (
self.gradient_accumulation_steps
)

if self.amp:
amp = strategy.amp
amp.enable = self.amp
amp.dtype = self.amp_dtype
amp.level = self.amp_level.lower()
if self.amp_master_grad:
amp.use_master_grad = True

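            # dist.to_static wraps the model, loader, criterion and optimizer
            # into a DistModel that runs under the static-graph auto-parallel
            # engine with the strategy configured above.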
dist_model = dist.to_static(
model,
dist_loader,
criterion,
optimizer,
strategy=strategy,
)

dist_model.train()
for step, inputs in enumerate(dist_loader()):
input_ids, labels = inputs
loss = dist_model(input_ids, labels)
                logging.info(f"step: {step} loss: {loss}")

if step >= 10:
break

def run_test_cases(self):
self.run_llama(only_build=True)
# self.run_llama(to_static=0)
# self.run_llama(to_static=1)


if __name__ == '__main__':
TestParallelAPI().run_test_cases()