[NEW Feature] Add hook-based refined_recompute support #9396

Merged: 20 commits, Nov 27, 2024
5 changes: 5 additions & 0 deletions llm/run_finetune.py
@@ -65,6 +65,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -146,6 +147,10 @@ def main():
)

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
model_args.lora,
)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
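Note on the new call: training_args.refined_recompute is a comma-separated string (see the LlmMetaConfig entries further down), and update_refined_recompute converts it into the skip_recompute_ops mapping that the modeling code later queries. The following is a hypothetical sketch of that conversion only; it assumes a plain comma-separated format, an assumed "-1 means all layers" convention, and it ignores the LoRA-specific handling. The real parser lives in paddlenlp/transformers/refined_recompute.py.

# Hypothetical sketch of the string -> dict conversion; the real
# update_refined_recompute may use a richer syntax (e.g. per-op layer counts)
# and also has LoRA-specific handling that is skipped here.
SUPPORTED_OPS = (
    "mlp_row_ln",
    "mlp_column_ln",
    "attention_row_ln",
    "attention_column_ln",
    "flash_attn",
)

def update_refined_recompute_sketch(expr):
    """Map 'attention_column_ln,flash_attn' to {'attention_column_ln': -1, 'flash_attn': -1}."""
    skip_ops = {}
    for name in (part.strip() for part in expr.split(",") if part.strip()):
        if name not in SUPPORTED_OPS:
            raise ValueError("unknown refined_recompute op: " + name)
        skip_ops[name] = -1  # assumed convention: -1 means "all layers"
    return skip_ops

print(update_refined_recompute_sketch("attention_column_ln,flash_attn"))
# {'attention_column_ln': -1, 'flash_attn': -1}
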
4 changes: 4 additions & 0 deletions llm/run_pretrain.py
@@ -44,6 +44,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device
@@ -413,6 +414,9 @@ def main():
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
# set all llm config
LlmMetaConfig.set_llm_config(config, training_args)
config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
8 changes: 8 additions & 0 deletions paddlenlp/transformers/configuration_utils.py
@@ -268,6 +268,14 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
Review comment (Collaborator): Will this configuration also get passed through to downstream tasks?

Review comment (Collaborator): Do we need _set_unsavable_keys?

Reply (Member, Author): Not needed; zhonghui is familiar with this usage. I looked at the implementation and it can meet the requirement: (1) @llmmetaclass is applied, and (2) LlmMetaConfig.set_llm_config(model_config, training_args) is called. For reference:

    @dataclass
    @llmmetaclass
    @add_start_docstrings(TrainingArguments.__doc__)
    class TrainingArguments(TrainingArguments):

# refined_recompute attributes
(
"refined_recompute",
str,
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
]

@classmethod
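For readers unfamiliar with the mechanism discussed in the review thread above: LlmMetaConfig holds (name, type, default, help) tuples, @llmmetaclass is expected to register them as extra TrainingArguments fields, and LlmMetaConfig.set_llm_config(model_config, training_args) copies the values onto the model config. Below is a minimal sketch of that flow under those assumptions; the real implementations live in paddlenlp/transformers/configuration_utils.py and do more than this.

# Hypothetical, simplified flow; the real @llmmetaclass and set_llm_config
# in paddlenlp/transformers/configuration_utils.py do more than this.
from types import SimpleNamespace

ATTRIBUTES = [
    # (name, type, default, help) -- mirrors the tuples added in this diff
    ("refined_recompute", str, "", "comma-separated ops to skip recompute for"),
    ("skip_recompute_ops", dict, None, "derived per-op mapping"),
]

def set_llm_config(model_config, training_args):
    """Copy every registered LLM attribute from the training args onto the model config."""
    for name, _type, default, _help in ATTRIBUTES:
        setattr(model_config, name, getattr(training_args, name, default))

training_args = SimpleNamespace(refined_recompute="flash_attn", skip_recompute_ops=None)
model_config = SimpleNamespace()
set_llm_config(model_config, training_args)
print(model_config.refined_recompute)  # flash_attn
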
14 changes: 11 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
except:
flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
Review comment (Contributor): Why is it called no_recompute? It feels a bit odd.

Reply (Member, Author): We could change it to skip_recompute instead; that would also work.

Reply (Member, Author): recompute(func, xxxxx) vs no_recompute(func, xxxxxx)

from paddlenlp.transformers.ring_flash_attention import RingFlashAttention


@@ -174,6 +175,7 @@ def fusion_flash_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@
attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

if hasattr(F, "flashmask_attention"):
attn_output = F.flashmask_attention(
attn_output = no_recompute(
F.flashmask_attention,
query_states,
key_states,
value_states,
startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
causal=True,
enable=skip_recompute,
)
else:
attn_output = F.flash_attention_with_sparse_mask(
attn_output = no_recompute(
F.flash_attention_with_sparse_mask,
query_states,
key_states,
value_states,
attn_mask_start_row_indices=attn_mask_startend_row_indices,
is_causal=True,
enable=skip_recompute,
)
else:
attn_output = F.scaled_dot_product_attention(
attn_output = no_recompute(
F.scaled_dot_product_attention,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=query_states.shape[1] != 1,
enable=skip_recompute,
)
attn_weights = None

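The calling convention introduced above is no_recompute(func, *args, enable=..., **kwargs): with enable=False it should behave like a plain func(*args, **kwargs) call, while with enable=True (inside a recomputed layer) the op's outputs are kept so the backward pass does not replay it. The stub below only mirrors that interface under this assumption; the real hook-based implementation is in paddlenlp/transformers/refined_recompute.py.

# Interface-only stub of the wrapper used above; the real no_recompute hooks
# into Paddle's recompute machinery so that, when enable=True inside a
# recomputed block, the wrapped op's outputs are saved instead of being
# recomputed in the backward pass.
def no_recompute(func, *args, enable=False, **kwargs):
    if enable:
        pass  # hook registration / activation saving would happen here
    return func(*args, **kwargs)

# Usage mirrors the diff:
# attn_output = no_recompute(F.flashmask_attention, q, k, v, ..., enable=skip_recompute)
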
54 changes: 52 additions & 2 deletions paddlenlp/transformers/llama/modeling.py
@@ -29,7 +29,15 @@
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -216,6 +224,7 @@ def scaled_dot_product_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -233,6 +242,7 @@
sequence_parallel,
reshard_layer,
npu_is_casual,
skip_recompute=skip_recompute,
)

# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -605,10 +615,24 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(
@@ -719,9 +743,22 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -821,6 +858,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

self.attn_func = scaled_dot_product_attention

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if (
config.recompute
and not config.recompute_use_reentrant
and config.skip_recompute_ops.get("flash_attn", False)
):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

def _init_rope(self):
if (
hasattr(self.config, "rope_scaling")
@@ -1471,7 +1516,12 @@ def __init__(self, config: LlamaConfig):
)

self.layers = nn.LayerList(
[LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)

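The two __init__ hunks above (LlamaMLP and LlamaAttention) repeat the same pattern: start from the regular parallel or sequence-parallel linear classes and swap in the RR* variants for whichever ops are flagged in skip_recompute_ops. A condensed sketch of that selection logic follows; the helper name and signature are illustrative, not part of the PR, which inlines the checks instead.

# Illustrative helper only; the PR inlines this logic in LlamaMLP.__init__ and
# LlamaAttention.__init__ rather than factoring it out.
def pick_parallel_linears(config, column_key, row_key,
                          default_column, default_row, rr_column, rr_row):
    column_cls, row_cls = default_column, default_row
    # refined_recompute only takes effect when recompute_use_reentrant=False
    if config.recompute and not config.recompute_use_reentrant:
        if config.skip_recompute_ops.get(column_key, False):
            column_cls = rr_column
        if config.skip_recompute_ops.get(row_key, False):
            row_cls = rr_row
    return column_cls, row_cls

# e.g. for the MLP in tensor-parallel (non sequence-parallel) mode:
# ColumnParallelLinear, RowParallelLinear = pick_parallel_linears(
#     config, "mlp_column_ln", "mlp_row_ln",
#     linear_utils.ColumnParallelLinear, linear_utils.RowParallelLinear,
#     RRColumnParallelLinear, RRRowParallelLinear,
# )
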
11 changes: 9 additions & 2 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -22,9 +22,12 @@
PipelineLayer,
SharedLayerDesc,
)
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
recompute,
)
from paddlenlp.utils.tools import get_env_device

from .modeling import (
@@ -371,7 +374,11 @@ def get_hcg():

for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
LayerDesc(
LlamaDecoderLayerPipe,
config=create_skip_config_for_refined_recompute(i, config),
layerwise_recompute=i not in self.no_recompute_layers,
),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
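Both the static and the pipeline Llama models now build each decoder layer from create_skip_config_for_refined_recompute(i, config), i.e. a per-layer view of the config. Its exact policy is not shown in this diff; the sketch below is hypothetical and assumes the integer values in skip_recompute_ops mean "apply to the first N layers", with -1 meaning all layers.

import copy

# Hypothetical per-layer policy; the real create_skip_config_for_refined_recompute
# lives in paddlenlp/transformers/refined_recompute.py and may encode a
# different rule. Assumption here: each value N in skip_recompute_ops means
# "skip recompute for this op in the first N decoder layers", with -1 = all.
def create_skip_config_for_refined_recompute(layer_idx, config):
    layer_config = copy.deepcopy(config)
    skip_ops = {}
    for op_name, num_layers in (config.skip_recompute_ops or {}).items():
        skip_ops[op_name] = (num_layers == -1) or (layer_idx < num_layers)
    layer_config.skip_recompute_ops = skip_ops
    return layer_config
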
48 changes: 45 additions & 3 deletions paddlenlp/transformers/qwen/modeling.py
@@ -24,9 +24,18 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
no_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:
@@ -154,9 +163,22 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if config.num_attention_heads % config.tensor_parallel_degree != 0:
@@ -227,12 +249,19 @@ def _attn(self, query, key, value, attention_mask=None):
return_softmax=self.config.attn_dropout_prob > 0.0,
)
else:
attn_output = F.scaled_dot_product_attention(
skip_recompute = (
self.config.recompute
and not self.config.recompute_use_reentrant
and self.config.skip_recompute_ops.get("flash_attn", False)
)
attn_output = no_recompute(
F.scaled_dot_product_attention,
query,
key,
value,
attn_mask=attention_mask,
is_causal=attention_mask is None,
enable=skip_recompute,
)
attn_weights = None

@@ -388,9 +417,22 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_ffn:
@@ -684,7 +726,7 @@ def __init__(self, config):
self.h = nn.LayerList(
[
QWenBlock(
config,
create_skip_config_for_refined_recompute(i, config),
)
for i in range(config.num_hidden_layers)
]
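The skip_recompute expression in QWenAttention._attn above mirrors the guard used in LlamaAttention.__init__: the flash-attention recompute skip only applies when recompute is on, reentrant recompute is off, and flash_attn is flagged in skip_recompute_ops. That three-part check reads as a small predicate; the helper below is illustrative, not an actual PaddleNLP function.

# Illustrative predicate mirroring the inline checks in LlamaAttention.__init__
# and QWenAttention._attn; not an actual PaddleNLP helper.
def should_skip_recompute(config, op_name):
    return bool(
        config.recompute
        and not config.recompute_use_reentrant
        and config.skip_recompute_ops.get(op_name, False)
    )

# e.g. skip_recompute = should_skip_recompute(self.config, "flash_attn")
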
5 changes: 4 additions & 1 deletion paddlenlp/transformers/qwen/modeling_pp.py
@@ -18,6 +18,9 @@
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
)

from .modeling import (
QWenBlock,
@@ -170,7 +173,7 @@ def get_hcg():
self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(QWenBlockPipe, config=config),
LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)),
f"qwen.h.{i}",
)
self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f")