[Hackathon 5th No.64] Add dynamic-to-static (@to_static) training support for PaddleNLP models - part #7576

Merged: 4 commits merged on Dec 18, 2023
4 changes: 0 additions & 4 deletions model_zoo/bert/README.md
@@ -65,7 +65,6 @@ python -m paddle.distributed.launch --gpus "0" run_pretrain.py \
--weight_decay 1e-2 \
--adam_epsilon 1e-6 \
--warmup_steps 10000 \
--num_train_epochs 3 \
Contributor Author: The `num_train_epochs` argument is not used anywhere in this script, and setting it raises an error, so it has been removed.

--input_dir data/ \
--output_dir pretrained_models/ \
--logging_steps 1 \
@@ -83,7 +82,6 @@ python -m paddle.distributed.launch --gpus "0" run_pretrain.py \
- `weight_decay` is the weight-decay coefficient used by the AdamW optimizer.
- `adam_epsilon` is the epsilon value used by the AdamW optimizer.
- `warmup_steps` is the number of warmup steps for the dynamic learning-rate schedule.
- `num_train_epochs` is the number of training epochs.
- `input_dir` is the input data directory; every file in this directory whose name contains "training" is used as training data.
- `output_dir` is the directory where the model is saved.
- `logging_steps` is the interval, in steps, between log outputs.
@@ -128,7 +126,6 @@ python -m paddle.distributed.launch --xpus "0" run_pretrain.py \
--weight_decay 1e-2 \
--adam_epsilon 1e-6 \
--warmup_steps 10000 \
--num_train_epochs 3 \
--input_dir data/ \
--output_dir pretrained_models/ \
--logging_steps 1 \
@@ -146,7 +143,6 @@ python -m paddle.distributed.launch --xpus "0" run_pretrain.py \
- `weight_decay` is the weight-decay coefficient used by the AdamW optimizer.
- `adam_epsilon` is the epsilon value used by the AdamW optimizer.
- `warmup_steps` is the number of warmup steps for the dynamic learning-rate schedule.
- `num_train_epochs` is the number of training epochs.
- `input_dir` is the input data directory; every file in this directory whose name contains "training" is used as training data.
- `output_dir` is the directory where the model is saved.
- `logging_steps` is the interval, in steps, between log outputs.
4 changes: 4 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -264,6 +264,10 @@
if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")

if self.args.to_static:
model = paddle.jit.to_static(model)
logger.info("Successfully to apply @to_static to the whole model.")

if self.args.should_save or self.args.should_save_model_state:
os.makedirs(self.args.output_dir, exist_ok=True)

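For reference, a minimal sketch of the behaviour this hunk introduces, assuming the `to_static` field is available on `TrainingArguments` as added below in `training_args.py`. The `paddle.nn.Linear` stand-in model and the output directory are illustrative only and not part of this PR.

```python
# Hedged usage sketch (not part of this diff): exercising the new `to_static` flag.
import paddle
from paddlenlp.trainer import TrainingArguments

# `output_dir` is required; `to_static` is the field this PR adds to TrainingArguments.
args = TrainingArguments(output_dir="./checkpoints", to_static=True)

# A tiny stand-in model, purely to illustrate the wrapping step the Trainer now performs.
model = paddle.nn.Linear(8, 2)

if args.to_static:
    # Same conversion the Trainer applies before training: dynamic graph to static graph.
    model = paddle.jit.to_static(model)

out = model(paddle.randn([4, 8]))  # forward now runs through the converted static program
print(out.shape)  # [4, 2]
```
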
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -705,6 +705,10 @@
default=False,
metadata={"help": "Whether to unify hybrid parallel checkpoint."},
)
to_static: Optional[bool] = field(
default=False,
metadata={"help": "Enable training under @to_static."},
)
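
As a usage note, this field should surface as a `--to_static` command-line flag through PaddleNLP's usual dataclass argument parsing. The sketch below assumes the standard `PdArgumentParser` behaviour (mirrored from `HfArgumentParser`) and an illustrative output directory.

```python
# Hedged sketch: how the new field is expected to surface as a CLI flag
# (flag name inferred from the dataclass field; not part of this diff).
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./out", "--to_static", "true"]
)
print(training_args.to_static)  # True
```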

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -809,7 +813,7 @@

try:
self.use_auto_parallel = self.parallel_mode == "auto"
except:

pass

if paddle.distributed.get_world_size() > 1 and (
@@ -1050,34 +1054,34 @@
fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

elif self.use_auto_parallel:
world_size = paddle.distributed.get_world_size()
tensor_parallel_degree = max(self.tensor_parallel_degree, 1)

pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
world_size % (tensor_parallel_degree * pipeline_parallel_degree) == 0
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."

self.data_parallel_degree = world_size // (tensor_parallel_degree * pipeline_parallel_degree)

if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:

self.sharding_parallel_degree = self.data_parallel_degree

sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
if sharding_parallel_degree == 1 and len(self.sharding) > 0:
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")

self.sharding = []

if ShardingOption.OFFLOAD in self.sharding:

warnings.warn("`offload` is not supported NOW!")

strategy = fleet.auto.Strategy()
if pipeline_parallel_degree > 1:
pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
for x in pipeline_parallel_config:
if len(x) > 0:

if x not in [
# "disable_p2p_cache_shape", # no need for auto_parallel
# "disable_partial_send_recv", # no implemenation for auto_parallel
@@ -1086,108 +1090,108 @@
# "enable_sharding_comm_overlap", # no implemenation for auto_parallel
# "enable_timer", # no implemenation for auto_parallel
]:
raise ValueError(

f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
)

pipeline = strategy.pipeline
pipeline.enable = True
pipeline.accumulate_steps = self.gradient_accumulation_steps
pipeline.micro_batch_size = self.per_device_train_batch_size
pipeline.schedule_mode = "1F1B"

if self.amp_master_grad:
warnings.warn("`amp_master_grad` is not supported NOW in AutoParallel!")
self.amp_master_grad = False
logger.info(f"PP configs:{strategy.pipeline}, use master_grad: {self.amp_master_grad}")

if self.do_eval:
assert (
self.per_device_train_batch_size * self.gradient_accumulation_steps

== self.per_device_eval_batch_size
), (
"In pipeline model, the evaluation also shares same setting with training. "
"Please set per_device_eval_batch_size=per_device_train_batch_size * gradient_accumulation_steps."
)

if tensor_parallel_degree > 1:
mp_optimization = strategy.mp_optimization

if " " in self.tensor_parallel_config:
mp_config = set(self.tensor_parallel_config.split(" "))

else:
mp_config = set(self.tensor_parallel_config.split(","))

for x in mp_config:
if len(x) > 0:
if x not in [

"enable_mp_async_allreduce", # allreduce_matmul_grad_overlapping in auto_parallel
# "enable_mp_skip_c_identity",

# "enable_mp_fused_linear_param_grad_add",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "

f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
)
try:
if "enable_mp_async_allreduce" in mp_config:
mp_optimization.allreduce_matmul_grad_overlapping = True

except:
warnings.warn(
"The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "
"by current version of Paddle. Please try latest develop Paddle."
)

if sharding_parallel_degree > 1:
sharding = strategy.sharding

sharding.enable = True
sharding.degree = sharding_parallel_degree
if ShardingOption.SHARD_OP in self.sharding:
sharding.stage = 1
elif ShardingOption.SHARD_GRAD_OP in self.sharding:
sharding.stage = 2
elif ShardingOption.FULL_SHARD in self.sharding:
sharding.stage = 3

sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
for x in sharding_parallel_config:
if len(x) > 0:
if x not in [
# "enable_stage1_tensor_fusion",

# "enable_stage1_overlap",
# "enable_stage2_overlap",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap."

)

if (
"enable_stage1_overlap" in sharding_parallel_config
or "enable_stage2_overlap" in sharding_parallel_config

):
sharding.reduce_overlap = True

if self.bf16 or self.fp16:

amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16" if self.bf16 else "float16"
amp.level = self.fp16_opt_level

amp.init_loss_scaling = self.scale_loss
amp.custom_black_list = self.amp_custom_black_list
amp.custom_white_list = self.amp_custom_white_list

if self.recompute:
recompute = strategy.recompute
recompute.enable = True

self.strategy = strategy

logger.info(self.strategy)
order = ["dp", "pp", "mp"]
degree = [self.data_parallel_degree, pipeline_parallel_degree, tensor_parallel_degree]
mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree))))
if not mesh_dims:
mesh_dims = [("dp", 1)]
fleet.auto.create_mesh(mesh_dims)

else:
world_size = paddle.distributed.get_world_size()
if world_size > 1:
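
To make the auto-parallel degree arithmetic above concrete, here is a standalone sketch with assumed values (8 ranks, tensor parallel degree 2, pipeline parallel degree 2). It reproduces only the pure-Python mesh-dimension computation, not the `fleet` calls.

```python
# Illustrative re-computation of the mesh dimensions derived above; all values are assumed.
world_size = 8
tensor_parallel_degree = 2
pipeline_parallel_degree = 2

# data_parallel_degree = world_size // (tp * pp) -> 8 // (2 * 2) = 2
data_parallel_degree = world_size // (tensor_parallel_degree * pipeline_parallel_degree)

order = ["dp", "pp", "mp"]
degree = [data_parallel_degree, pipeline_parallel_degree, tensor_parallel_degree]

# Keep only parallel dimensions with degree > 1, as done before fleet.auto.create_mesh.
mesh_dims = [dim for dim in zip(order, degree) if dim[1] > 1]
print(mesh_dims)  # [('dp', 2), ('pp', 2), ('mp', 2)]
```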