[AutoParallel] add pipeline.auto_parallel_profiler to auto_config (#7343)

* update

* update

* Fix init weight for llama modeling auto

* update

* add support for Llama2

* recover codes

* remove training_args

* fix

* remove test config

* use guard

* change var name

* add import

* fix import error

* fix import error in run_pretrain_auto.py

---------

Co-authored-by: chenruibiao <chenruibiao@baidu.com>
AndSonder and From00 authored Dec 15, 2023
1 parent b90cf05 commit 5106809
Showing 3 changed files with 33 additions and 1 deletion.
19 changes: 19 additions & 0 deletions llm/llama/auto_parallel/run_pretrain_auto.py
@@ -27,6 +27,7 @@
import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto
from paddle.profiler.utils import job_schedule_profiler_range

from paddlenlp.trainer import (
PdArgumentParser,
@@ -100,6 +101,17 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)

job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
)
job_schedule_profiler_end: int = field(
default=-1,
metadata={"help": "The step to end job_schedule_profiler."},
)
parallel_mode: str = field(default="hybrid", metadata={"help": ""})

pipeline_schedule_mode: str = field(
default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."}
)
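
The two new fields above are plain dataclass fields, so (assuming PdArgumentParser keeps its HfArgumentParser-style behavior of turning dataclass fields into CLI flags) the profiling window can be set from the launch command. A minimal, hypothetical sketch — ProfilerWindowArgs is a stand-in, not the real PreTrainingArguments:

```python
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser


@dataclass
class ProfilerWindowArgs:
    # Hypothetical stand-in for the two new PreTrainingArguments fields.
    job_schedule_profiler_start: int = field(
        default=-1, metadata={"help": "The step to start job_schedule_profiler."}
    )
    job_schedule_profiler_end: int = field(
        default=-1, metadata={"help": "The step to end job_schedule_profiler."}
    )


if __name__ == "__main__":
    # e.g. python this_sketch.py --job_schedule_profiler_start 10 --job_schedule_profiler_end 20
    (args,) = PdArgumentParser([ProfilerWindowArgs]).parse_args_into_dataclasses()
    print(args.job_schedule_profiler_start, args.job_schedule_profiler_end)
```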
@@ -621,6 +633,10 @@ def loss_func(loss, outputs):
global_step_last_logged = 0
start_time_last_logged = time.time()
tr_loss = float(0)

job_schedule_profiler_start = training_args.job_schedule_profiler_start
job_schedule_profiler_end = training_args.job_schedule_profiler_end

local_batches = []
for epoch_idx in range(num_train_epochs):
for step, inputs in enumerate(train_dataloader):
@@ -630,6 +646,9 @@ def loss_func(loss, outputs):
elif pp_degree > 1:
local_batches = inputs

with job_schedule_profiler_range(step, job_schedule_profiler_start, job_schedule_profiler_end) as status:
engine.enable_job_schedule_profiler = status

for micro_batch in local_batches:
outs = engine.run(micro_batch, mode="train")

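For context on the loop above: job_schedule_profiler_range is used as a context manager that yields a status flag for the current step, and the engine's enable_job_schedule_profiler flag is set from it each iteration. The sketch below is a hypothetical re-implementation of that expected behavior (profiling active only within the configured step window, -1 meaning disabled); it is not Paddle's implementation.

```python
from contextlib import contextmanager


@contextmanager
def profiler_range_sketch(step, start, end):
    # Assumed semantics: profiling is active for steps in [start, end); start == -1 disables it.
    yield start != -1 and start <= step < end


class FakeEngine:
    # Stand-in for the auto-parallel engine's enable_job_schedule_profiler flag.
    enable_job_schedule_profiler = False


engine = FakeEngine()
for step in range(6):
    with profiler_range_sketch(step, start=2, end=4) as status:
        engine.enable_job_schedule_profiler = status
    print(step, engine.enable_job_schedule_profiler)  # True only for steps 2 and 3
```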
13 changes: 12 additions & 1 deletion model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py
@@ -13,13 +13,15 @@
# limitations under the License.

import os
import sys
import numpy as np

import paddle
import paddle.base.core as core
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.profiler import SummaryView
from paddle.profiler.utils import job_schedule_profiler_range

try:
from ppfleetx.optims import build_lr_scheduler, build_optimizer
@@ -80,7 +82,10 @@ def __init__(self, configs, module=None, mode="train"):

# Distributed
self._pp_degree = configs["Distributed"]["pp_degree"]

pipeline_cfg = configs.Distributed.get("pipeline", {})
self._job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
self._job_schedule_profiler_end = pipeline_cfg.get("job_schedule_profiler_end", -1)

# engine configs
self._configs = configs["Engine"]

@@ -140,6 +145,9 @@ def __init__(self, configs, module=None, mode="train"):
self.memory_stats = configs.get("Profiler_auto", {}).get("memory_stats", False)
self.nvprof_start = configs.get("Profiler_auto", {}).get("nvprof_start", -1)
self.nvprof_end = configs.get("Profiler_auto", {}).get("nvprof_end", -1)

if (self._job_schedule_profiler_start != -1) and use_new_executor():
logger.info("Schedule Profiler start at step {} and end at step {}".format(self._job_schedule_profiler_start, self._job_schedule_profiler_end))

def _validate_batch(self, batch):
if self._pp_degree > 1 or self._accumulate_steps == 1:
@@ -174,6 +182,9 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loade
self._auto_engine.prepare(mode="train")

for step, batch in enumerate(train_data_loader):
with job_schedule_profiler_range(step, self._job_schedule_profiler_start, self._job_schedule_profiler_end) as status:
self._auto_engine.enable_job_schedule_profiler = status

if epoch_index == self._load_recovery["epoch"]:
if step < self._load_recovery["step"]:
continue
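The auto_engine change above reads the profiling window from the nested Distributed.pipeline section with -1 defaults, so omitting the keys leaves the profiler disabled. A minimal sketch of that lookup, with a plain dict standing in for ppfleetx's attribute-style config object and made-up step values:

```python
# Plain-dict stand-in for the ppfleetx config; the step values are illustrative only.
configs = {
    "Distributed": {
        "pp_degree": 2,
        "pipeline": {
            "job_schedule_profiler_start": 10,
            "job_schedule_profiler_end": 20,
        },
    }
}

# Same defaulting pattern as the diff: missing keys fall back to -1 (profiler disabled).
pipeline_cfg = configs["Distributed"].get("pipeline", {})
start = pipeline_cfg.get("job_schedule_profiler_start", -1)
end = pipeline_cfg.get("job_schedule_profiler_end", -1)
print(start, end)  # 10 20; would print -1 -1 without a pipeline section
```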
2 changes: 2 additions & 0 deletions model_zoo/gpt-3/ppfleetx/utils/auto_config.py
@@ -240,6 +240,8 @@ def process_strategy(config):
pipeline.schedule_mode = pipeline_cfg.get("schedule_mode", "1F1B")
pipeline.micro_batch_size = config.Global.micro_batch_size
pipeline.accumulate_steps = accumulate_steps
pipeline.job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
pipeline.job_schedule_profiler_stop = pipeline_cfg.get("job_schedule_profiler_stop", -1)

elif accumulate_steps > 1:
# gradient merge config
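To illustrate what the two added lines in process_strategy do: the profiler window from the user's pipeline config is copied onto the strategy's pipeline section. Below is a hedged sketch with a SimpleNamespace standing in for the real fleet auto strategy object; attribute and key names follow the diff (note the start/stop naming here versus start/end elsewhere), and the values are made up:

```python
from types import SimpleNamespace

pipeline = SimpleNamespace()  # stand-in for strategy.pipeline, not the real strategy object
pipeline_cfg = {
    "job_schedule_profiler_start": 10,  # illustrative values
    "job_schedule_profiler_stop": 20,
}

# Mirrors the two added lines: forward the window onto the pipeline strategy, defaulting to -1.
pipeline.job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
pipeline.job_schedule_profiler_stop = pipeline_cfg.get("job_schedule_profiler_stop", -1)
print(pipeline.job_schedule_profiler_start, pipeline.job_schedule_profiler_stop)
```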
