diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 3eb70dbab6ae..438aa8b3849b 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -65,6 +65,7 @@ register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig +from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.trl import SFTTrainer from paddlenlp.trl.llm_utils import ( ZeroPaddingIterDatasetCallback, @@ -146,6 +147,10 @@ def main(): ) LlmMetaConfig.set_llm_config(model_config, training_args) + model_config.refined_recompute = update_refined_recompute( + training_args.refined_recompute, + model_args.lora, + ) model_config.use_fast_layer_norm = model_args.use_fast_layer_norm # Config for model using dropout, such as GPT. diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 26d94cde7ab0..76cdbdbcb7ac 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -44,6 +44,7 @@ register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass +from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.utils.batch_sampler import DistributedBatchSampler from paddlenlp.utils.log import logger from paddlenlp.utils.tools import get_env_device @@ -413,6 +414,9 @@ def main(): config = AutoConfig.from_pretrained(model_args.model_name_or_path) # set all llm config LlmMetaConfig.set_llm_config(config, training_args) + config.refined_recompute = update_refined_recompute( + training_args.refined_recompute, + ) config.use_fast_layer_norm = model_args.use_fast_layer_norm config.seq_length = data_args.max_seq_length diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index c1fd2e0c530f..eb7fb6060a50 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ b/paddlenlp/transformers/configuration_utils.py @@ -268,6 +268,14 @@ class LlmMetaConfig: "Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']", ), ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"), + # refined_recompute attributes + ( + "refined_recompute", + str, + "", + "refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']", + ), + ("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"), ] @classmethod diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 43af1dd99e5c..164151a8c807 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -51,6 +51,7 @@ def swiglu(x, y=None): except: flash_attention = None +from paddlenlp.transformers.refined_recompute import no_recompute from paddlenlp.transformers.ring_flash_attention import RingFlashAttention @@ -174,6 +175,7 @@ def fusion_flash_attention( sequence_parallel=False, reshard_layer=None, npu_is_casual=False, + skip_recompute=False, ): bsz, q_len, num_heads, head_dim = query_states.shape _, kv_seq_len, _, _ = value_states.shape @@ -257,28 +259,34 @@ def fusion_flash_attention( attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1) if hasattr(F, "flashmask_attention"): - attn_output = F.flashmask_attention( + attn_output = no_recompute( + F.flashmask_attention, query_states, key_states, value_states, startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1), causal=True, + enable=skip_recompute, ) else: - attn_output = 
F.flash_attention_with_sparse_mask( + attn_output = no_recompute( + F.flash_attention_with_sparse_mask, query_states, key_states, value_states, attn_mask_start_row_indices=attn_mask_startend_row_indices, is_causal=True, + enable=skip_recompute, ) else: - attn_output = F.scaled_dot_product_attention( + attn_output = no_recompute( + F.scaled_dot_product_attention, query_states, key_states, value_states, attn_mask=attention_mask, is_causal=query_states.shape[1] != 1, + enable=skip_recompute, ) attn_weights = None diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 17a7517e6f05..099abbbff68c 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -29,7 +29,15 @@ from paddle.autograd import PyLayer from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -from paddle.distributed.fleet.utils import recompute + +from paddlenlp.transformers.refined_recompute import ( + RRColumnParallelLinear, + RRColumnSequenceParallelLinear, + RRRowParallelLinear, + RRRowSequenceParallelLinear, + create_skip_config_for_refined_recompute, + recompute, +) try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -216,6 +224,7 @@ def scaled_dot_product_attention( sequence_parallel=False, reshard_layer=None, npu_is_casual=False, + skip_recompute=False, ): bsz, q_len, num_heads, head_dim = query_states.shape _, kv_seq_len, _, _ = value_states.shape @@ -233,6 +242,7 @@ def scaled_dot_product_attention( sequence_parallel, reshard_layer, npu_is_casual, + skip_recompute=skip_recompute, ) # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] @@ -605,10 +615,24 @@ def __init__(self, config): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowParallelLinear + if config.tensor_parallel_degree > 1: if config.fuse_attention_ffn: self.gate_up_fused_proj = ColumnParallelLinear( @@ -719,9 +743,22 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = 
linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: if self.fuse_attention_qkv: @@ -821,6 +858,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): self.attn_func = scaled_dot_product_attention + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if ( + config.recompute + and not config.recompute_use_reentrant + and config.skip_recompute_ops.get("flash_attn", False) + ): + self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) + def _init_rope(self): if ( hasattr(self.config, "rope_scaling") @@ -1471,7 +1516,12 @@ def __init__(self, config: LlamaConfig): ) self.layers = nn.LayerList( - [LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)] + [ + LlamaDecoderLayer( + create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers + ) + for i in range(config.num_hidden_layers) + ] ) self.norm = LlamaRMSNorm(config) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index f4598aec1014..1ec4c027a72a 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -22,9 +22,12 @@ PipelineLayer, SharedLayerDesc, ) -from paddle.distributed.fleet.utils import recompute from paddlenlp.transformers.model_utils import PipelinePretrainedModel +from paddlenlp.transformers.refined_recompute import ( + create_skip_config_for_refined_recompute, + recompute, +) from paddlenlp.utils.tools import get_env_device from .modeling import ( @@ -371,7 +374,11 @@ def get_hcg(): for i in range(config.num_hidden_layers): self.add_sequential_layer( - LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers), + LayerDesc( + LlamaDecoderLayerPipe, + config=create_skip_config_for_refined_recompute(i, config), + layerwise_recompute=i not in self.no_recompute_layers, + ), f"llama.layers.{i}", ) self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama") diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index d97c2783382f..58e49e69d989 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -24,9 +24,18 @@ from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker -from paddle.distributed.fleet.utils import recompute from paddle.utils import try_import +from paddlenlp.transformers.refined_recompute import ( + RRColumnParallelLinear, + RRColumnSequenceParallelLinear, + RRRowParallelLinear, + RRRowSequenceParallelLinear, + create_skip_config_for_refined_recompute, + no_recompute, + recompute, +) + try: from paddle.incubate.nn.functional import swiglu except ImportError: @@ -154,9 +163,22 @@ def __init__(self, config): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported 
when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: if config.num_attention_heads % config.tensor_parallel_degree != 0: @@ -227,12 +249,19 @@ def _attn(self, query, key, value, attention_mask=None): return_softmax=self.config.attn_dropout_prob > 0.0, ) else: - attn_output = F.scaled_dot_product_attention( + skip_recompute = ( + self.config.recompute + and not self.config.recompute_use_reentrant + and self.config.skip_recompute_ops.get("flash_attn", False) + ) + attn_output = no_recompute( + F.scaled_dot_product_attention, query, key, value, attn_mask=attention_mask, is_causal=attention_mask is None, + enable=skip_recompute, ) attn_weights = None @@ -388,9 +417,22 @@ def __init__(self, config): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: if self.fuse_attention_ffn: @@ -684,7 +726,7 @@ def __init__(self, config): self.h = nn.LayerList( [ QWenBlock( - config, + create_skip_config_for_refined_recompute(i, config), ) for i in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py index 47357d6921e3..889ed60e5416 100644 --- a/paddlenlp/transformers/qwen/modeling_pp.py +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -18,6 +18,9 @@ from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer from paddlenlp.transformers.model_utils import PipelinePretrainedModel +from paddlenlp.transformers.refined_recompute import ( + create_skip_config_for_refined_recompute, +) from .modeling import ( QWenBlock, @@ -170,7 +173,7 @@ def get_hcg(): self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen") for i in range(config.num_hidden_layers): self.add_sequential_layer( - LayerDesc(QWenBlockPipe, config=config), + LayerDesc(QWenBlockPipe, 
config=create_skip_config_for_refined_recompute(i, config)), f"qwen.h.{i}", ) self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f") diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index ced5c6ed7052..95061f55f15d 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -31,7 +31,15 @@ from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -from paddle.distributed.fleet.utils import recompute + +from paddlenlp.transformers.refined_recompute import ( + RRColumnParallelLinear, + RRColumnSequenceParallelLinear, + RRRowParallelLinear, + RRRowSequenceParallelLinear, + create_skip_config_for_refined_recompute, + recompute, +) from .. import linear_utils from ..activations import ACT2FN @@ -160,6 +168,7 @@ def scaled_dot_product_attention( attn_mask_startend_row_indices=None, training=True, sequence_parallel=False, + skip_recompute=False, ): bsz, q_len, num_heads, head_dim = query_states.shape _, kv_seq_len, _, _ = value_states.shape @@ -177,6 +186,7 @@ def scaled_dot_product_attention( output_attentions, attn_mask_startend_row_indices=attn_mask_startend_row_indices, sequence_parallel=sequence_parallel, + skip_recompute=skip_recompute, ) else: # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] @@ -363,10 +373,24 @@ def __init__(self, config: Qwen2Config, is_shared=False): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowParallelLinear + if config.tensor_parallel_degree > 1: self.gate_proj = ColumnParallelLinear( self.hidden_size, @@ -465,10 +489,24 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if config.skip_recompute_ops.get("attention_column_ln", False): + 
ColumnParallelLinear = RRColumnParallelLinear + if config.skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowParallelLinear + if config.tensor_parallel_degree > 1: self.q_proj = ColumnParallelLinear(self.hidden_size, self.hidden_size, has_bias=True, gather_output=False) self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip @@ -488,6 +526,14 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): self.attn_func = scaled_dot_product_attention + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if ( + config.recompute + and not config.recompute_use_reentrant + and config.skip_recompute_ops.get("flash_attn", False) + ): + self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) + def forward( self, hidden_states, @@ -901,7 +947,10 @@ def __init__(self, config: Qwen2Config): self.layers = nn.LayerList( [ - Qwen2DecoderLayer(config, layerwise_recompute=layer_idx not in self.no_recompute_layers) + Qwen2DecoderLayer( + create_skip_config_for_refined_recompute(layer_idx, config), + layerwise_recompute=layer_idx not in self.no_recompute_layers, + ) for layer_idx in range(config.num_hidden_layers) ] ) diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index ae0396ba6312..916baad328ce 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -23,7 +23,11 @@ PipelineLayer, SharedLayerDesc, ) -from paddle.distributed.fleet.utils import recompute + +from paddlenlp.transformers.refined_recompute import ( + create_skip_config_for_refined_recompute, + recompute, +) from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel @@ -294,7 +298,11 @@ def get_hcg(): for i in range(config.num_hidden_layers): self.add_sequential_layer( - LayerDesc(Qwen2DecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers), + LayerDesc( + Qwen2DecoderLayerPipe, + config=create_skip_config_for_refined_recompute(i, config), + layerwise_recompute=i not in self.no_recompute_layers, + ), f"qwen2.layers.{i}", ) self.add_sequential_layer(LayerDesc(Qwen2RMSNormPipe, config=config), "qwen2") diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py new file mode 100644 index 000000000000..0884e4d688df --- /dev/null +++ b/paddlenlp/transformers/refined_recompute.py @@ -0,0 +1,825 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
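+
+ # Refined recompute overview: inside a `recompute()`-wrapped layer, selected ops
+ # ("mlp_row_ln", "mlp_column_ln", "attention_row_ln", "attention_column_ln", "flash_attn")
+ # can skip recomputation. `no_recompute` / `NoRecomputeContext` wrap a single op,
+ # the RR*ParallelLinear layers apply this to the parallel linear layers, and
+ # `update_refined_recompute` / `create_skip_config_for_refined_recompute` turn the
+ # user-facing config string into a per-layer `skip_recompute_ops` dict.
+ # Only takes effect when recompute is enabled with `recompute_use_reentrant=False`.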
+ +from __future__ import annotations + +import contextlib +import copy +import inspect +import queue +import uuid +import weakref +from copy import deepcopy + +import paddle +import paddle.autograd +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, +) +from paddle.distributed.fleet.recompute.recompute import check_recompute_necessary +from paddle.distributed.fleet.recompute.recompute import recompute as original_recompute +from paddle.distributed.fleet.recompute.recompute import switch_rng_state_tracker + +try: + from paddle.distributed.fleet.utils import sequence_parallel_utils +except ImportError: + sequence_parallel_utils = None +from paddle.distributed.fleet.layers.mpu import mp_layers, mp_ops + +from paddlenlp.transformers.linear_utils import ( + ColumnParallelLinear, + ColumnSequenceParallelLinear, + RowParallelLinear, + RowSequenceParallelLinear, +) +from paddlenlp.utils.log import logger + +try: + from paddle.base import core, framework +except ImportError: + from paddle.fluid import core, framework + +__all__ = [ + "NoRecomputeContext", + "no_recompute", + "recompute", + "get_global_rr_queue_dict", + "update_refined_recompute", + "RRColumnSequenceParallelLinear", + "RRRowSequenceParallelLinear", + "RRColumnParallelLinear", + "RRRowParallelLinear", +] +_in_no_recompute = False +global_rr_queue_dict = {} +recompute_suffix = "@recompute" +_recompute_id = -1 + +# https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_7th/%E3%80%90Hackathon%207th%E3%80%91FundableProject%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md#%E4%B9%9Dpaddle-lod-%E9%80%80%E5%9C%BA%E6%B8%85%E7%90%86 +if hasattr(core.VarDesc.VarType, "DENSE_TENSOR"): + DENSE_TENSOR = core.VarDesc.VarType.DENSE_TENSOR +else: + DENSE_TENSOR = core.VarDesc.VarType.LOD_TENSOR + + +def set_recompute_id(value=-1): + """switch recompute id to the given value""" + global _recompute_id + _recompute_id = str(value) + + +def get_recompute_id(): + """get current recompute id""" + global _recompute_id + return str(_recompute_id) + + +@contextlib.contextmanager +def switch_recompute_id_ctx(value=-1): + """switch recompute id to the given value within the context""" + raw_recompute_id = get_recompute_id() + set_recompute_id(value) + yield + set_recompute_id(raw_recompute_id) + + +def in_no_recompute_ctx(): + """check if in no recompute context""" + global _in_no_recompute + return _in_no_recompute + + +def set_no_recompute(value=True): + """set whether in no recompute mode""" + global _in_no_recompute + _in_no_recompute = value + + +@contextlib.contextmanager +def switch_recompute_ctx(kwargs): + """switch recompute context to the given value within the context""" + for ts in kwargs.values(): + if paddle.is_tensor(ts) and not ts.name.endswith(recompute_suffix): + # 1. add recompute suffix to the tensor name + ts.name = ts.name + recompute_suffix + # 2. set in no recompute mode + set_no_recompute(True) + yield + for ts in kwargs.values(): + if paddle.is_tensor(ts) and ts.name.endswith(recompute_suffix): + # 3. remove recompute suffix from the tensor name + ts.name = ts.name[: -len(recompute_suffix)] + # 4. 
reset in no recompute mode + set_no_recompute(False) + + +def get_global_rr_queue_dict(): + """get global rr queue dict""" + global global_rr_queue_dict + return global_rr_queue_dict + + +# def print_global_rr_queue_info(name="pack"): +# queue_dict = get_global_rr_queue_dict() +# print("{:<10} {:<20} {:<10}".format("Action", "Queue Name", "Queue Size")) +# print("-" * 50) +# for k, v in queue_dict.items(): +# print("{:<10} {:<20} {:<10}".format(name, k, v.qsize())) +# print("=" * 50) + + +def parse_to_kwargs(function, *args, **kwargs): + """Parse the function arguments into a dictionary.""" + signature = inspect.signature(function) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return bound_arguments.arguments + + +class _NoopSaveInputs(paddle.autograd.PyLayer): + """ + This layer does nothing but save all input tensors. + This is used to prevent the gradients of the inputs being computed. + """ + + @staticmethod + def forward(ctx, *args): + """This function does nothing but save all input tensors.""" + tensors = [o for o in args if isinstance(o, paddle.Tensor)] + ctx.save_for_backward(*tensors) + return paddle.empty((0,), dtype=tensors[0].dtype) + + @staticmethod + def backward(ctx, *args): + """Should not be called since we don't support backward on this graph.""" + raise AssertionError("Did not expect to backward on this graph") + + +def no_recompute(function, *args, **kwargs): + """ + Within a recompute context, do not recompute intermediate activations. + + Parameters: + function (paddle.nn.Layer): The layer or sequence of layers that describe a part of the model's + forward pass, whose intermediate activations will not be released. + *args (Tensor): Input tensors to the function. + **kwargs (Dict): Keyword arguments to the function. + + Returns: + The output of the function given the input tensors and keyword arguments. + """ + recompute_id_with_suffix = get_recompute_id() + # enable kwargs, in no recompute context, has grad + enable = kwargs.pop("enable", True) and recompute_id_with_suffix != "-1" and framework._dygraph_tracer()._has_grad + keys_ignore_to_save = kwargs.pop("keys_ignore_to_save", []) + if not enable: + return function(*args, **kwargs) + + if isinstance(function, paddle.nn.Layer): + func = function.forward + input_kwargs = parse_to_kwargs(func, *args, **kwargs) + elif isinstance(function, paddle.autograd.PyLayer): + func = function.apply + input_kwargs = parse_to_kwargs(function.forward, *args, **kwargs) + else: + func = function + input_kwargs = parse_to_kwargs(func, *args, **kwargs) + + is_first_fwd = recompute_id_with_suffix.endswith("@first") + recompute_id = recompute_id_with_suffix.split("@")[0] + + if is_first_fwd: + if recompute_id not in global_rr_queue_dict: + global_rr_queue_dict[recompute_id] = queue.Queue() + + with switch_recompute_ctx(input_kwargs): + result = func(*args, **kwargs) + + global_rr_queue_dict[recompute_id].put(result) + else: + tensor_list = [] + for key, val in input_kwargs.items(): + if key in keys_ignore_to_save: + continue + if val is not None and paddle.is_tensor(val): + tensor_list.append(val) + + if len(tensor_list) > 0: + _NoopSaveInputs.apply(*tensor_list) + + result = global_rr_queue_dict[recompute_id].get() + + if global_rr_queue_dict[recompute_id].empty(): + global_rr_queue_dict.pop(recompute_id) + return result + + +class NoRecomputeContext: + """ + A Context Manager class that do not recompute intermediate activations. 
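+
+     Example (illustrative; `some_layer` and `hidden_states` are placeholders):
+
+         with NoRecomputeContext(enable=True) as ctx:
+             out = ctx(some_layer, hidden_states)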
+ """ + + def __init__(self, enable=True, keys_ignore_to_save=[]): + """initialize the RefinedRecomputeFunction object.""" + self._enable = enable + self._keys_ignore_to_save = keys_ignore_to_save + + def __enter__(self): + """enter the context manager.""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """exit the context manager.""" + pass + + def __call__(self, function, *args, **kwargs): + """ + Within a recompute context, do not recompute intermediate activations. + + Parameters: + function (paddle.nn.Layer): The layer or sequence of layers that describe a part of the model's + forward pass, whose intermediate activations will not be released. + *args (Tensor): Input tensors to the function. + **kwargs (Dict): Keyword arguments to the function. + + Returns: + The output of the function given the input tensors and keyword arguments. + """ + kwargs["enable"] = self._enable + kwargs["keys_ignore_to_save"] = self._keys_ignore_to_save + return no_recompute(function, *args, **kwargs) + + +def share_buffer_to_tensor_or_param(inner_x): + """share buffer to tensor or param""" + if hasattr(inner_x, "main_grad"): + # donot deepcopy the `main_grad` to save memory + state = copy.deepcopy({k: v for k, v in inner_x.__dict__.items() if k != "main_grad"}) + tmp_tensor = framework.EagerParamBase( + shape=inner_x.shape, dtype=inner_x.dtype, name=inner_x.name + "cpy", **state + ) + setattr(tmp_tensor, "main_grad", inner_x.main_grad) + inner_x._unsafe_share_buffer_to(tmp_tensor) + else: + if inner_x.is_dist(): + # TODO(jeff41404): it seems better to use `tmp_tensor = core.eager.Tensor(inner_x)`, + # but other errors will be triggered during the current period, and can be modified after resolution + tmp_tensor = core.eager.Tensor( + inner_x.dtype, + inner_x.shape, + inner_x.name + "cpy", + DENSE_TENSOR, + inner_x.persistable, + inner_x.process_mesh, + inner_x.placements, + ) + else: + tmp_tensor = core.eager.Tensor( + inner_x.dtype, + inner_x.shape, + inner_x.name + "cpy", + DENSE_TENSOR, + inner_x.persistable, + ) + inner_x._unsafe_share_buffer_to(tmp_tensor) + tmp_tensor.stop_gradient = inner_x.stop_gradient + return tmp_tensor + + +def _recompute_without_reentrant(function, preserve_rng_state=True, *args, **kwargs): + """ + recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd. 
+ """ + + if preserve_rng_state: + cur_device = paddle.get_device() + if "gpu:" in cur_device: + fw_cuda_rng_state = paddle.get_cuda_rng_state() + elif "cpu" in cur_device: + fw_cuda_rng_state = paddle.get_rng_state() + elif "xpu:" in cur_device: + fw_cuda_rng_state = paddle.get_rng_state() + elif cur_device.split(":")[0] in paddle.device.get_all_custom_device_type(): + fw_cuda_rng_state = paddle.get_rng_state(cur_device) + else: + raise RuntimeError(f"Recompute with RNG preserve is not support current device: {cur_device}.") + fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states_tracker() + tracer = framework._dygraph_tracer() + is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + amp_level = "O2" + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + amp_level = "O1" + + if tracer._amp_dtype == "float16": + amp_dtype = "float16" + elif tracer._amp_dtype in ("bfloat16", "float32"): + amp_dtype = "bfloat16" + + amp_white_list, amp_black_list = tracer._get_amp_op_list() + + class IntermediateHolder: + def __init__(self, name, shape, dtype) -> None: + self.name = name + self.shape = shape + self.dtype = dtype + + storage = weakref.WeakKeyDictionary() + holder_list = [] + # generate a unique id for the recompute context + recompute_id = str(int(uuid.uuid4())) + + def pack(x): + # [PACK] in no recompute context or input tensor no need recompute, return the input tensor directly + if x.persistable or (in_no_recompute_ctx() and not x.name.endswith(recompute_suffix)): + return share_buffer_to_tensor_or_param(x) + + # remove the recompute suffix + res = IntermediateHolder(x.name, x.shape, x.dtype) + holder_list.append(weakref.ref(res)) + return res + + def unpack(x): + # [UNPACK] in no recompute context or input tensor no need recompute, return the input tensor directly + if paddle.is_tensor(x): + return x + + unpack_counter = 0 + if len(storage) == 0: + + def inner_pack(inner_x): + if inner_x.persistable: + return + + nonlocal unpack_counter + unpack_counter += 1 + + if unpack_counter - 1 >= len(holder_list): + raise Exception( + "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute." + "Please check your `keys_ignore_to_save`." + ) + + if holder_list[unpack_counter - 1]() is None: + return + if inner_x is None: + storage[holder_list[unpack_counter - 1]()] = None + return + + storage[holder_list[unpack_counter - 1]()] = share_buffer_to_tensor_or_param(inner_x) + return + + def inner_unpack(inner_x): + raise Exception("An unexpected backward called on a tensor!") + + rng_cxt_manager = ( + contextlib.nullcontext() + if not preserve_rng_state + else switch_rng_state_tracker(fw_cuda_rng_state, fwd_cuda_rng_state_tracker) + ) + with rng_cxt_manager: + with paddle.set_grad_enabled(True): + with paddle.amp.auto_cast( + enable=is_fw_autocast, + custom_white_list=amp_white_list, + custom_black_list=amp_black_list, + level=amp_level, + dtype=amp_dtype, + ): + with switch_recompute_id_ctx(recompute_id + "@second"): + with paddle.autograd.saved_tensors_hooks(inner_pack, inner_unpack): + unused_outputs = function(*args, **kwargs) # noqa: F841 + + if x not in storage: + raise Exception( + "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute." + ) + tensor = storage.pop(x) + assert x.shape == tensor.shape, ( + f"The shape:{x.shape} of the tensor saved by autograd is not " + f"consistent with the original tensor shape:{tensor.shape}! 
" + ) + assert x.dtype == tensor.dtype, ( + f"The dtype:{x.dtype} of the tensor saved by autograd is not" + f"consistent with the original tensor dtype:{tensor.dtype}! " + ) + return tensor + + with switch_recompute_id_ctx(recompute_id + "@first"): + with paddle.autograd.saved_tensors_hooks(pack, unpack): + outputs = function(*args, **kwargs) + + return outputs + + +def recompute(function, *args, **kwargs): + """ + recompute intermediate activations to save then memory. + + Parameters: + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs to the function. + **kwargs(Dict): Kwargs should only contain two kinds of key-value params, the one is part of function's key-value params, + and the other contains 'preserve_rng_state' and 'use_reentrant'. the key-value pair of preserve_rng_state, + which is used to indicate whether to save the forward rng. If it is True, then the last forward rng value + will be restored when the forward recalculation of backpropagation is performed, its default value is True. + the key-value pair of use_reentrant is used to indicate which implementation of recompute you will be used. + 'use_reentrant=True' means to use the PyLayer implementation of recompute, 'use_reentrant=False' means to + use the Hook implementation of recompute, its default value is True. + Returns: + Output of function on args. + """ + preserve = kwargs.pop("preserve_rng_state", True) + use_reentrant = kwargs.pop("use_reentrant", True) + if not use_reentrant: + if framework._dygraph_tracer()._has_grad: + check_args = list(args) + check_args.extend(list(kwargs.values())) + check_recompute_necessary(check_args) + return _recompute_without_reentrant(function, preserve, *args, **kwargs) + else: + kwargs["preserve_rng_state"] = preserve + kwargs["use_reentrant"] = use_reentrant + return original_recompute(function, *args, **kwargs) + + +def get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_recompute_num=-1): + """ + Get the selected layers to skip recompute. + + Args: + - skip_recompute_num (int, optional): The number of stages to skip recompute. If not provided or is negative + one, it means that all layers should be skipped. Default: -1. + + Returns: + - :obj:`set`: A set containing the selected layers to skip recompute. + + """ + + assert pp_size > 1, ( + "Only support pipeline parallel, " f"pp_size must be greater than 1, but got pp_size: {pp_size}" + ) + + if skip_recompute_num == -1: + # select all layers to skip recompute + skip_recompute_num = vp_size + + no_recompute_layer_num = [] + if skip_recompute_num == 0: + return set(no_recompute_layer_num) + + if vp_size == 1: + # If vp_size == 1, we can not select model chunk for pp, + # so if skip_recompute_num > 0, we select the all layers to skip recompute. 
+ if skip_recompute_num > 0: + return set(range(layer_num)) + else: + return set() + + assert layer_num % (pp_size * vp_size) == 0, ( + "layer_num must be divisible by pp_size * vp_size," + f" but got layer_num: {layer_num}, pp_size: {pp_size}, vp_size: {vp_size}" + ) + + chunk_size = layer_num // (pp_size * vp_size) + chunk_list = [list(range(i * chunk_size, (i + 1) * chunk_size)) for i in range(pp_size * vp_size)] + + stage_chunk_list = [[] for _ in range(pp_size)] + for i in range(pp_size * vp_size): + stage_chunk_list[i % pp_size].append(chunk_list[i]) + + for i in range(pp_size): + no_recompute_layer_num.extend(stage_chunk_list[i][-skip_recompute_num:]) + + # Convert to 1D list + return set(sum(no_recompute_layer_num, [])) + + +def create_skip_config_for_refined_recompute(layer_idx, config): + """ + Creates a configuration for skipping recomputation based on the configuration file, + effective only at the specified layer index. + + Args: + layer_idx (int): The layer index used to check whether recomputation should be skipped. + config (dict): The configuration file of the input model. + + Returns: + dict: Returns an updated configuration file containing the following key-value pairs: + - skip_recompute_ops (dict): A dictionary with each operation's name and a boolean + indicating whether to skip recomputation, defaults to None. + - If the refined_recompute key does not exist or recompute is set to False, + the original configuration file is returned. + + """ + if not config.recompute: + return config + skip_config = dict() + config = deepcopy(config) + + try: + hcg = fleet.get_hybrid_communicate_group() + pp_size = max(hcg.get_pipe_parallel_world_size(), 1) + except: + pp_size = 1 + + for op_name, skip_num in config.refined_recompute.items(): + # is pp model + if pp_size > 1: + vp_size = max(config.virtual_pp_degree, 1) + layer_num = config.num_layers if hasattr(config, "num_layers") else config.num_hidden_layers + no_recompute_layers = get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_num) + if layer_idx in no_recompute_layers: + skip_config[op_name] = True + else: + skip_config[op_name] = False + else: + if skip_num == 0: # 0 means all recompute + skip_config[op_name] = False + elif skip_num < 0: # < 0 means all skip recompute + skip_config[op_name] = True + else: + if layer_idx < skip_num: # < the number of layers to skip recompute + skip_config[op_name] = True + else: + skip_config[op_name] = False + config.skip_recompute_ops = skip_config + return config + + +def update_refined_recompute(rr, lora=False): + """update refined recompute dict.""" + if rr == "": + return {} + else: + + rr_res = { + "mlp_row_ln": 0, + "attention_row_ln": 0, + "attention_column_ln": 0, + "mlp_column_ln": 0, + "flash_attn": 0, + } + ops = rr.split(",") + enable_rr = False + for op in ops: + if ":" not in op: + raise ValueError("Illegal refined_recompute input, please check.") + op_name, skip_num = op.split(":")[0], int(op.split(":")[1]) + if op_name not in rr_res: + raise ValueError(f"Refined recompute do not support {op_name}, please check.") + + if op_name in ["mlp_row_ln", "attention_row_ln", "attention_column_ln", "mlp_column_ln"]: + if lora: + logger.warning( + "Currently, LoRA does not support refined recompute " + f"for the `{op_name}` op. This refined recompute op will be ignored." 
+ ) + continue + rr_res[op_name] = skip_num + if skip_num != 0: + enable_rr = True + + if not enable_rr: + rr_res = {} + return rr_res + + +class RRColumnParallelLinear(ColumnParallelLinear): + def forward(self, x): + # use inner api to process identity + def _overlap_linear(): + return mp_layers.InnerOverlapLinear.apply( + x, + self.weight, + self.bias, + self.fuse_matmul_bias, + self.mp_async_allreduce, + self.mp_skip_c_identity, + self.mp_fused_linear_param_grad_add, + self.model_parallel_group, + ) + + if self.mp_async_allreduce: + output_parallel = _overlap_linear() + else: + if self.is_mp: + input_parallel = mp_ops._c_identity( + x, + group=self.model_parallel_group, + skip_c_identity_dynamic=self.mp_skip_c_identity, + ) + else: + input_parallel = x + + def fwd(input_parallel): + return self.linear(input_parallel, self.weight, self.bias, name=self._name) + + output_parallel = no_recompute(fwd, input_parallel) + + if self.gather_output and self.is_mp: + output = mp_ops._c_concat(output_parallel, group=self.model_parallel_group) + else: + output = output_parallel + return output + + +class RRRowParallelLinear(RowParallelLinear): + def forward(self, x): + if self.input_is_parallel or (not self.is_mp): + input_parallel = x + else: + # split last dim + input_parallel = mp_ops._c_split(x, group=self.model_parallel_group) + + if self.is_mp: + if self.fuse_matmul_bias: + bias = mp_layers.MPScale.apply(self.bias, self.world_size) + else: + bias = None + + def fwd(input_parallel): + output_parallel = self.linear(input_parallel, self.weight, bias, name=self._name) + output_ = mp_ops._mp_allreduce( + output_parallel, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True, + skip_c_identity_dynamic=self.mp_skip_c_identity, + ) + return output_ + + output_ = no_recompute(fwd, input_parallel) + + if not self.fuse_matmul_bias and self.bias is not None: + output = output_ + self.bias + else: + output = output_ + else: + output = self.linear(input_parallel, self.weight, self.bias, name=self._name) + + return output + + +class RRColumnSequenceParallelLinear(ColumnSequenceParallelLinear): + """RRColumnSequenceParallelLinear""" + + def forward(self, x): + if self.mp_async_allreduce: + output = sequence_parallel_utils.SPInnerOverlapLinear.apply( + x, + self.weight, + self.bias, + self.fuse_matmul_bias, + self.recompute_allgather, + self.mp_fused_linear_param_grad_add, + self.model_parallel_group, + ) + else: + input_parallel = sequence_parallel_utils.AllGatherOp.apply(x) if self.is_mp else x + + def fwd(input_parallel): + output = self.linear(input_parallel, self.weight, self.bias, name=self._name) + return output + + # create a dummpy fwd function + output = no_recompute(fwd, input_parallel) + return output + + +class RRRowSequenceParallelLinear(RowSequenceParallelLinear): + """RRRowSequenceParallelLinear""" + + def forward(self, x): + input_parallel = x + if self.is_mp: + if self.mp_scale is not None: + bias = self.mp_scale(self.bias, self.world_size) + else: + bias = None + + def fwd(input_parallel): + output_parallel = self.linear(input_parallel, self.weight, bias, name=self._name) + output_ = sequence_parallel_utils.ReduceScatterOp.apply(output_parallel) + return output_ + + # create a dummpy fwd function + output_ = no_recompute(fwd, input_parallel) + # register_hook to all_reduce self.bias + if bias is None and self.bias is not None: + output = output_ + self.bias + else: + output = output_ + else: + output = self.linear(input_parallel, self.weight, self.bias, 
name=self._name) + return output + + +# if __name__ == "__main__": +# # test flashmask_attention +# paddle.seed(2024) +# from paddle.nn.functional.flash_attention import flashmask_attention + +# dtype = "float16" +# paddle.set_default_dtype(dtype) + +# in_weight_shape = (32, 3 * 2 * 32) +# linear1 = paddle.nn.Linear( +# in_weight_shape[0], +# in_weight_shape[-1], +# ) +# paddle.seed(2024) +# in_weight = paddle.create_parameter(shape=in_weight_shape, dtype=dtype, name="in_weight") +# in_weight.set_value(paddle.normal(0, 0.02, in_weight_shape)) +# in_weight.main_grad = paddle.normal(0, 0.02, in_weight.shape).cast("float32") +# linear1.weight.set_value(in_weight) +# in_bias = paddle.create_parameter(shape=(in_weight.shape[-1],), dtype=dtype, name="in_bias", is_bias=True) +# in_bias.main_grad = paddle.normal(0, 0.02, in_bias.shape).cast("float32") +# linear1.bias.set_value(in_bias) +# linear1.weight.main_grad = in_weight.main_grad +# linear1.bias.main_grad = in_bias.main_grad + +# out_weight_shape = (2 * 32, 32) +# out_weight = paddle.create_parameter(shape=out_weight_shape, dtype=dtype, name="out_weight") +# out_weight.set_value(paddle.normal(0, 0.02, out_weight_shape)) +# out_weight.main_grad = paddle.normal(0, 0.02, out_weight.shape).cast("float32") + +# class cus_multiply(paddle.autograd.PyLayer): +# @staticmethod +# def forward(ctx, a, b): +# y = paddle.multiply(a, b) +# ctx.save_for_backward(a, b) +# return y + +# @staticmethod +# def backward(ctx, dy): +# a, b = ctx.saved_tensor() +# grad_a = dy * a +# grad_b = dy * b +# return grad_a, grad_b + +# multiply = cus_multiply.apply + +# def fwd(x, startend_row_indices, enable=True): +# def fwd_linear(x): +# weight = multiply(linear1.weight, linear1.weight * 0.1) +# bias = multiply(linear1.bias, linear1.bias * 0.1) +# qkv = paddle.nn.functional.silu(paddle.nn.functional.linear(x, weight, bias)) +# q, k, v = paddle.chunk(qkv, 3, axis=-1) +# q = q.reshape([q.shape[0], q.shape[1], 2, q.shape[2] // 2]) +# k = k.reshape([k.shape[0], k.shape[1], 2, v.shape[2] // 2]) +# v = v.reshape([v.shape[0], k.shape[1], 2, v.shape[2] // 2]) +# return q, k, v + +# q, k, v = no_recompute(fwd_linear, x, enable=enable) + +# q, k, v = q * q, k * k, v * v +# out = no_recompute( +# flashmask_attention, +# q, +# k, +# v, +# startend_row_indices=startend_row_indices, +# causal=True, +# enable=enable, +# ) +# out = out.flatten(-2, -1) +# out = paddle.matmul(out, out_weight) +# return out + +# x = paddle.normal(0, 0.02, (1, 128, 32)) +# x.stop_gradient = False +# x_input = x +# startend_row_indices = paddle.randint(0, 128, (1, 2, 128, 1), dtype="int32") + +# enable = True +# # 第一层 +# o1 = recompute( +# fwd, +# x, +# startend_row_indices, +# enable=enable, +# ) +# # 第二层 +# o2 = recompute(fwd, o1 + x, startend_row_indices, enable=enable) +# # 第三层 +# o3 = recompute(fwd, o2 + x, startend_row_indices, enable=enable) + +# o3.sum().backward() +# print(x_input.grad.mean()) +# print(linear1.weight.grad.mean()) +# print(out_weight.grad.mean()) diff --git a/tests/fixtures/llm/finetune.yaml b/tests/fixtures/llm/finetune.yaml index abe9aad5d39e..eacd7b4740d8 100644 --- a/tests/fixtures/llm/finetune.yaml +++ b/tests/fixtures/llm/finetune.yaml @@ -23,6 +23,7 @@ finetune: eval_with_do_generation: false metric_for_best_model: "accuracy" recompute: true + refined_recompute: "flash_attn:-1" save_total_limit: 1 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 diff --git a/tests/fixtures/llm/lora.yaml b/tests/fixtures/llm/lora.yaml index 5d75cb752682..b8a3fb730676 100644 --- 
a/tests/fixtures/llm/lora.yaml +++ b/tests/fixtures/llm/lora.yaml @@ -22,6 +22,7 @@ lora: eval_with_do_generation: false metric_for_best_model: "accuracy" recompute: true + refined_recompute: "flash_attn:-1" save_total_limit: 1 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py new file mode 100644 index 000000000000..25a1cdee7bf5 --- /dev/null +++ b/tests/transformers/test_refined_recompute.py @@ -0,0 +1,559 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +environment_variables = { + "NVIDIA_TF32_OVERRIDE": "0", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", +} +for k, v in environment_variables.items(): + os.environ[k] = v +import unittest +from typing import Optional, Tuple + +import paddle +import paddle.device +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.distributed.fleet.recompute import recompute as original_recompute + +from paddlenlp.transformers.refined_recompute import no_recompute as rr_no_recompute +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute +from paddlenlp.utils.import_utils import is_paddle_cuda_available + +ACT2FN = { + "relu": F.relu, + "gelu": F.gelu, + "tanh": F.tanh, + "sigmoid": F.sigmoid, +} +dtype = paddle.float16 + + +class PyLayerMatmul(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + return a @ b + + @staticmethod + def backward(ctx, dy): + a, b = ctx.saved_tensor() + if hasattr(a, "main_grad"): + a.main_grad.add_(paddle.ones_like(a.main_grad)) + if hasattr(b, "main_grad"): + b.main_grad.add_(paddle.ones_like(b.main_grad)) + grad_a = paddle.matmul(dy, b, transpose_y=True) + grad_b = paddle.matmul(a, dy, transpose_x=True) + return grad_a, grad_b + + +pylayer_matmul = PyLayerMatmul.apply + + +class BertConfig: + def __init__( + self, + vocab_size: int = 30522, + hidden_size: int = 768, + num_hidden_layers: int = 4, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + max_position_embeddings: int = 1024, + type_vocab_size: int = 2, + initializer_range: float = 0.2, + pad_token_id: int = 0, + pool_act: str = "tanh", + layer_norm_eps: float = 1e-12, + output_attentions: bool = False, + output_hidden_states: bool = False, + num_labels=2, + recompute=False, + use_rr_recompute=False, + recompute_use_reentrant=False, + **kwargs + ): + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + 
self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.pool_act = pool_act + self.layer_norm_eps = layer_norm_eps + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.num_labels = num_labels + self.recompute = recompute + self.use_rr_recompute = use_rr_recompute + self.recompute_use_reentrant = recompute_use_reentrant + + +class BertEmbeddings(nn.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.register_buffer( + "position_ids", paddle.arange(config.max_position_embeddings, dtype="int64").reshape((1, -1)) + ) + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + input_shape = input_ids.shape + seq_length = input_ids.shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64) + + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + + reshape_fn = lambda x: x.reshape([0, 0, -1, self.attention_head_size]) + # compute q,k,v + query_layer = reshape_fn(self.query(hidden_states)) + key_layer = reshape_fn(self.key(hidden_states)) + value_layer = reshape_fn(self.value(hidden_states)) + + context_layer = rr_no_recompute( + F.scaled_dot_product_attention, + query=query_layer, + key=key_layer, + value=value_layer, + is_causal=True, + enable=self.config.use_rr_recompute and self.config.recompute, + ) + + new_context_layer_shape = context_layer.shape[:-2] + [ + self.all_head_size, + ] + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, None) if output_attentions else (context_layer,) + + return outputs + + +class BertSelfOutput(nn.Layer): + def __init__(self, config): + super().__init__() 
+ self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor: + hidden_states = rr_no_recompute( + self.dense, hidden_states, enable=self.config.use_rr_recompute and self.config.recompute + ) + hidden_states = self.dropout(hidden_states) + + hidden_states = rr_no_recompute( + self.LayerNorm, hidden_states + input_tensor, enable=self.config.use_rr_recompute and self.config.recompute + ) + return hidden_states + + +class BertAttention(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=False) + self.dense.weight.main_grad = paddle.zeros_like(self.dense.weight).cast("float32") + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + def pylayer_dense(hidden_states): + return pylayer_matmul(hidden_states, self.dense.weight) + + hidden_states = rr_no_recompute( + pylayer_dense, hidden_states, enable=self.config.use_rr_recompute and self.config.recompute + ) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class BertOutput(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor: + def custom_dense(hidden_states, weight, bias=None): + return F.linear(hidden_states, weight, bias) + + bias = self.dense.bias * 1.1 + hidden_states = rr_no_recompute( + custom_dense, + hidden_states, + weight=self.dense.weight, + bias=bias, + enable=self.config.use_rr_recompute and self.config.recompute, + keys_ignore_to_save=["bias"], + ) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + # self attn + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + 
+class BertLayer(nn.Layer):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[paddle.Tensor]:
+        # self attn
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # ffn
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+class BertEncoder(nn.Layer):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.LayerList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+    ) -> Tuple[paddle.Tensor]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for layer_module in self.layer:
+            # add hidden_states
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.training and self.config.recompute:
+                recompute_function = rr_recompute if self.config.use_rr_recompute else original_recompute
+                layer_outputs = recompute_function(
+                    layer_module,
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                    use_reentrant=self.config.recompute_use_reentrant,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                )
+            hidden_states = layer_outputs[0]
+
+            # add self attn
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        return tuple(
+            v
+            for v in [
+                hidden_states,
+                all_hidden_states,
+                all_self_attentions,
+            ]
+            if v is not None
+        )
+
+
+class BertPreTrainedModel(nn.Layer):
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        pass
+
+
+class BertModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(config)
+
+    def forward(
+        self,
+        input_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        token_type_ids: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> Tuple[paddle.Tensor]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if token_type_ids is None:
+            token_type_ids = paddle.zeros(input_ids.shape, dtype=paddle.int64)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+        return encoder_outputs
+
+
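+# The test below runs the same forward/backward under three settings -- plain
+# recompute, recompute with refined recompute (rr), and no recompute -- and
+# checks that the resulting gradients (and the main_grad accumulated through the
+# PyLayer-based intermediate weight) match, while recording allocator statistics.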
+class BertRefinedRecomputeTest(unittest.TestCase):
+    def no_pp_fwd_bwd(
+        self,
+        recompute=False,
+        use_rr_recompute=False,
+        recompute_use_reentrant=False,
+        num_hidden_layers=4,
+        shape=[2, 64],
+    ):
+        paddle.set_default_dtype(dtype)
+        paddle.seed(42)
+        config = BertConfig(
+            num_hidden_layers=num_hidden_layers,
+            recompute=recompute,
+            use_rr_recompute=use_rr_recompute,
+            recompute_use_reentrant=recompute_use_reentrant,
+        )
+        model = BertModel(config)
+        model.train()
+        input_ids = paddle.randint(10, config.vocab_size, shape=shape)
+        gpu_mem_used_before = paddle.device.cuda.memory_allocated()
+        outputs = model(input_ids=input_ids)[0]
+        gpu_mem_used_after = paddle.device.cuda.memory_allocated()
+        outputs.sum().backward()
+
+        # div = 1024**3  # GB
+        div = 1  # bytes
+        return (
+            model,
+            round((gpu_mem_used_after - gpu_mem_used_before) / div, 2),
+            round(paddle.device.cuda.max_memory_allocated() / div, 2),
+        )
+
+    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute is only supported on GPU")
+    def test_refined_recompute(self):
+        raw_dtype = paddle.get_default_dtype()
+
+        model1, mem_usage_forward1, max_mem_usage_forward1 = self.no_pp_fwd_bwd(
+            recompute=True, use_rr_recompute=False
+        )  # with recompute
+        model2, mem_usage_forward2, max_mem_usage_forward2 = self.no_pp_fwd_bwd(
+            recompute=True, use_rr_recompute=True
+        )  # with rr recompute
+        model3, mem_usage_forward3, max_mem_usage_forward3 = self.no_pp_fwd_bwd(
+            recompute=False, use_rr_recompute=False
+        )  # without recompute
+
+        name_list = [n for n, _ in model1.named_parameters()]
+
+        for param1, param2, name in zip(model1.parameters(), model3.parameters(), name_list):
+            # test main grad
+            if "intermediate.dense.weight" in name:
+                self.assertTrue(param1.main_grad.sum().item() > 0)
+                self.assertTrue(param2.main_grad.sum().item() > 0)
+            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
+
+        for param1, param2, name in zip(model2.parameters(), model3.parameters(), name_list):
+            # test main grad
+            if "intermediate.dense.weight" in name:
+                self.assertTrue(param1.main_grad.sum().item() > 0)
+                self.assertTrue(param2.main_grad.sum().item() > 0)
+            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
+
+        # self.assertTrue(mem_usage_forward1 < mem_usage_forward2 < mem_usage_forward3)
+        # self.assertTrue(max_mem_usage_forward1 < max_mem_usage_forward2 < max_mem_usage_forward3)
+
+        del model1, model2, model3
+        paddle.device.cuda.empty_cache()
+        paddle.set_default_dtype(raw_dtype)
+
+    def pp_fwd_bwd(
+        self,
+        recompute=False,
+        use_rr_recompute=False,
+        recompute_use_reentrant=False,
+        num_iter=4,
+        shape=[2, 64],
+    ):
+        paddle.set_default_dtype(dtype)
+        paddle.seed(42)
+        config = BertConfig(
+            num_hidden_layers=1,
+            recompute=recompute,
+            use_rr_recompute=use_rr_recompute,
+            recompute_use_reentrant=recompute_use_reentrant,
+        )
+        layer = BertLayer(config)
+        layer.train()
+
+        x = paddle.randn([*shape, config.hidden_size])
+        x.stop_gradient = False
+        x_copy = x
+
+        if layer.training and config.recompute:
+            recompute_function = rr_recompute if config.use_rr_recompute else original_recompute
+            for _ in range(num_iter):
+                x = recompute_function(layer, x, use_reentrant=config.recompute_use_reentrant)[0]
+        else:
+            for _ in range(num_iter):
+                x = layer(x)[0]
+
+        x.sum().backward()
+
+        return x_copy.grad, layer
+
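+    # Same gradient-equivalence check, but driving a single BertLayer num_iter
+    # times (as a pipeline-parallel stage would), with recompute applied per call
+    # in pp_fwd_bwd; the input gradient is also compared across settings.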
+    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp is only supported on GPU")
+    def test_refined_recompute_pp(self):
+        raw_dtype = paddle.get_default_dtype()
+        grad1, layer1 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=False)
+        grad2, layer2 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=True)
+        grad3, layer3 = self.pp_fwd_bwd(recompute=False, use_rr_recompute=False)
+
+        name_list = [n for n, _ in layer1.named_parameters()]
+
+        for param1, param2, name in zip(layer1.parameters(), layer3.parameters(), name_list):
+            # test main grad
+            if "intermediate.dense.weight" in name:
+                self.assertTrue(param1.main_grad.sum().item() > 0)
+                self.assertTrue(param2.main_grad.sum().item() > 0)
+            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
+
+        self.assertTrue(paddle.equal_all(grad1.cast("float32"), grad3.cast("float32")))
+        for param1, param2, name in zip(layer2.parameters(), layer3.parameters(), name_list):
+            # test main grad
+            if "intermediate.dense.weight" in name:
+                self.assertTrue(param1.main_grad.sum().item() > 0)
+                self.assertTrue(param2.main_grad.sum().item() > 0)
+            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
+
+        self.assertTrue(paddle.equal_all(grad2.cast("float32"), grad3.cast("float32")))
+
+        del grad1, grad2, grad3
+        del layer1, layer2, layer3
+        paddle.device.cuda.empty_cache()
+        paddle.set_default_dtype(raw_dtype)