Fix ernie ci auto trainer error #9758

Merged
4 changes: 0 additions & 4 deletions llm/auto_parallel/gpt-3/run_pretrain_auto.py
@@ -224,7 +224,6 @@ class ModelArguments:

hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})

use_fused_rope: Optional[bool] = field(
default=False,
metadata={"help": "Enable rope fusion or not."},
@@ -566,9 +565,6 @@ def fn(layer):
need_data=training_args.should_load_dataset,
)

# load_model_auto(model)
# model = shard_model(model)

trainer = PretrainingTrainer(
model=model,
criterion=criterion,
4 changes: 2 additions & 2 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -527,7 +527,9 @@ def main():
config.recompute_granularity = model_args.recompute_granularity
config.virtual_pp_degree = model_args.virtual_pp_degree
config.sequence_parallel = training_args.sequence_parallel

config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce

config.use_fused_rope = model_args.use_fused_rope
config.no_recompute_layers = model_args.no_recompute_layers
config.pp_recompute_interval = model_args.pp_recompute_interval
@@ -600,7 +602,6 @@ def fn(layer):
tokenizer,
need_data=training_args.should_load_dataset,
)

trainer = PretrainingTrainer(
model=model,
criterion=criterion,
@@ -610,7 +611,6 @@ def fn(layer):
eval_dataset=eval_dataset if training_args.do_eval else None,
optimizers=(None, lr_scheduler),
tokenizer=tokenizer,
model_args=model_args,
)

checkpoint = None
101 changes: 53 additions & 48 deletions paddlenlp/trainer/auto_trainer.py
@@ -20,20 +20,17 @@
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel.intermediate.parallelize as parallelize

import paddle.nn as nn
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.intermediate.parallelize import (
parallelize_model,
parallelize_optimizer,
)
from tqdm.auto import tqdm

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.model_utils import PretrainedModel

from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.log import logger
from .argparser import strtobool
from .auto_training_args import AutoTrainingArguments

from .trainer import SCALER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME
from .trainer_callback import TrainerState
from .trainer_utils import ( # set_hyrbid_parallel_seed,
@@ -70,59 +67,67 @@
return loss

kwargs.update({"criterion": loss_func})

sequence_parallel = False
if kwargs.get("model_args", None) is not None:
model_args = kwargs.pop("model_args")
if hasattr(model_args, "sequence_parallel"):
sequence_parallel = model_args.sequence_parallel

self.auto_dist_config = kwargs.pop("auto_dist_config", None)
model = kwargs.get("model", None)
assert model is not None

if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api:
model = kwargs.get("model", None)
assert model is not None
assert isinstance(model, PretrainedModel), f" AutoTrainer only support pretrained models,but got {model}"
for param in model.parameters():
assert not param._is_initialized(), "intermediate_api needs lazy init"

auto_dist_degree = {
"tensor_parallel": kwargs["args"].tensor_parallel_degree > 1,
"sequence_parallel": sequence_parallel,
"pipeline_parallel": kwargs["args"].pipeline_parallel_degree > 1,
"data_sharding_parallel": kwargs["args"].dataset_world_size > 1,
"sharding": kwargs["args"].sharding,
"sharding_mesh_dim": kwargs["args"].sharding_parallel_mesh_dimension,
}
auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)
self.auto_dist_config = auto_dist_config

model = parallelize_model(
model,
config=self.auto_dist_config,
)

kwargs["model"] = model

if not parallelize.has_parallelized_model:
model, self.auto_dist_config = self.parallel_model(model, kwargs["args"])
kwargs["model"] = model

else:
assert kwargs.get(
"auto_dist_config", None
), "if use AutoTrainer.parallel_model , auto_dist_config obtained from parallel_model should be passed to AutoTrainer "
self.auto_dist_config = kwargs.pop("auto_dist_config")

model = kwargs["model"]
for param in model.parameters():
if not param._is_initialized():
try:
param.initialize()
except Exception as e:
# NOTE(zhangwl): the param may not be initialized yet because its init_func is set later; users need to set the init func before constructing the auto trainer
logger.warning(
f"AutoTrainer requires all parameters to be initialized when auto_trainer init, but failed to initialize parameter {param.name} {param}.\n"
+ "Please check param init func.\n"
+ f"The original exception message is:\n{str(e)}"
)
# NOTE(zhangwl): in pipeline mode a param may still be uninitialized at this point even though its init_func is still set, so initialize it here
if not param._is_initialized() and param._init_func is not None:
param.initialize()

@jeff41404 (Contributor) commented on Jan 9, 2025:
if param._init_func is not None, should use param._init_func() or model._init_weights(Layer)?

kwargs["model"] = model

super().__init__(*args, **kwargs)
assert self.args.enable_auto_parallel

self.global_mesh = fleet.auto.get_mesh()
self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]

@classmethod
def parallel_model(cls, model, training_args: AutoTrainingArguments):

"""
Parallelize the model from a single card version to a distributed version.
Args:
model (paddle.nn.Layer): the model to be parallelized.
training_args (AutoTrainingArguments) : Training arguments which contain distributed information
Returns:
the parallelized model and the config that contains the distributed strategy
"""
if not training_args.use_intermediate_api:
return model, None
assert model is not None
for param in model.parameters():
if param._is_initialized():
logger.warning(

"intermediate_api needs lazy init because if param init before parallelize_model ,"
+ " param will be allocated the full amount of memory"
+ " We recommend reallocating memory after paralleliz-model to reduce the peak of memory allocation"
)

auto_dist_degree = {

"tensor_parallel": training_args.tensor_parallel_degree > 1,
"sequence_parallel": training_args.sequence_parallel,
"pipeline_parallel": training_args.pipeline_parallel_degree > 1,
"data_sharding_parallel": training_args.dataset_world_size > 1,
"sharding": training_args.sharding,
"sharding_mesh_dim": training_args.sharding_parallel_mesh_dimension,
}
auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)
model = parallelize.parallelize_model(

model,
config=auto_dist_config,
)
return model, auto_dist_config


def _nested_gather(self, tensors):
"""
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
@@ -170,7 +175,7 @@

if self.args.use_intermediate_api:
assert self.auto_dist_config is not None
self.optimizer = parallelize_optimizer(
self.optimizer = parallelize.parallelize_optimizer(

self.optimizer,
config=self.auto_dist_config,
)
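For context, a minimal caller-side sketch of the flow this refactor sets up is shown below. It is illustrative only: the helpers build_lazy_model and make_datasets are hypothetical placeholders, while AutoTrainer.parallel_model, the auto_dist_config handshake, and the internal parallelize.parallelize_optimizer call are the ones introduced in the diff above.

# Hypothetical sketch of the caller-side flow implied by this PR; build_lazy_model
# and make_datasets are placeholders, not APIs from this repository.
from paddlenlp.trainer.auto_trainer import AutoTrainer

def run_pretrain(training_args, criterion, tokenizer, lr_scheduler):
    model = build_lazy_model(training_args)  # parameters stay uninitialized (lazy init)
    train_dataset, eval_dataset = make_datasets(tokenizer)

    trainer_kwargs = dict(
        model=model,
        criterion=criterion,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizers=(None, lr_scheduler),
        tokenizer=tokenizer,
    )

    if training_args.use_intermediate_api:
        # Option 1: parallelize explicitly and hand the returned config back to the
        # trainer; AutoTrainer.__init__ asserts auto_dist_config is present in this case.
        model, auto_dist_config = AutoTrainer.parallel_model(model, training_args)
        trainer_kwargs.update(model=model, auto_dist_config=auto_dist_config)
    # Option 2: pass the lazy model directly; AutoTrainer.__init__ now calls
    # parallel_model itself when the model has not been parallelized yet.

    trainer = AutoTrainer(**trainer_kwargs)
    # With use_intermediate_api enabled, the optimizer created later is wrapped via
    # parallelize.parallelize_optimizer(self.optimizer, config=self.auto_dist_config).
    trainer.train()

Note that parallel_model now derives sequence_parallel and the other distributed degrees from training_args rather than model_args, which is why the run_pretrain_auto.py scripts above no longer forward model_args to the trainer.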
2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/modeling_auto.py
@@ -519,7 +519,7 @@
# repeat k/v heads if n_kv_heads < n_heads
# paddle version > 2.6 or develop support flash-attn with gqa/mqa
paddle_version = float(paddle.__version__[:3])
if (paddle_version != 0.0) and (paddle_version <= 2.6):
if not self.config.use_flash_attention or (paddle_version != 0.0) and (paddle_version <= 2.6):

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/modeling_network.py
@@ -422,7 +422,7 @@
# repeat k/v heads if n_kv_heads < n_heads
# paddle version > 2.6 or develop support flash-attn with gqa/mqa
paddle_version = float(paddle.__version__[:3])
if (paddle_version != 0.0) and (paddle_version <= 2.6):
if not self.config.use_flash_attention or (paddle_version != 0.0) and (paddle_version <= 2.6):

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

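As a closing note on the condition changed in both modeling files: Python evaluates "and" before "or", so the new check reads as (not use_flash_attention) or ((paddle_version != 0.0) and (paddle_version <= 2.6)). In other words, k/v heads are now repeated whenever flash attention is disabled, and also on Paddle releases up to 2.6, whose flash attention does not support GQA/MQA. The standalone sketch below restates that gating; it is illustrative only, and needs_kv_repeat is a hypothetical helper, not repository code.

# Illustrative restatement of the gating added in modeling_auto.py and
# modeling_network.py; needs_kv_repeat is a hypothetical helper, not repo code.
import paddle

def needs_kv_repeat(use_flash_attention: bool) -> bool:
    # "2.6.1" parses to 2.6 for release builds; develop builds such as "0.0.0" parse to 0.0.
    paddle_version = float(paddle.__version__[:3])
    old_paddle = (paddle_version != 0.0) and (paddle_version <= 2.6)
    # Repeat k/v heads when flash attention is off (the plain attention path expects
    # matching q and k/v head counts) or when an older Paddle flash-attn lacks GQA/MQA.
    return (not use_flash_attention) or old_paddle

# When this returns True, the attention layer expands the shared k/v heads:
#   key_states = repeat_kv(key_states, self.num_key_value_groups)
#   value_states = repeat_kv(value_states, self.num_key_value_groups)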