From 2f0b407d4a2e90e1a74bf1bd36e8c60db9e5f147 Mon Sep 17 00:00:00 2001
From: tianyuzhou668 <143938697+tianyuzhou668@users.noreply.github.com>
Date: Wed, 6 Nov 2024 17:03:51 +0800
Subject: [PATCH] [Llama] Support pp + no_recompute_layer. (#9373)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: 周天宇
---
 paddlenlp/transformers/llama/modeling_pp.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py
index 2efb06a903040a..f4598aec10142b 100644
--- a/paddlenlp/transformers/llama/modeling_pp.py
+++ b/paddlenlp/transformers/llama/modeling_pp.py
@@ -242,7 +242,12 @@ def forward(self, args):
             attn_mask_startend_row_indices = None
 
         has_gradient = not hidden_states.stop_gradient
-        if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+        if (
+            self.enable_recompute
+            and self.layerwise_recompute
+            and self.config.recompute_granularity == "full"
+            and has_gradient
+        ):
             if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
@@ -340,8 +345,6 @@ def __init__(self, config):
         self.recompute_granularity = self.config.recompute_granularity
         self.pp_recompute_interval = self.config.pp_recompute_interval
         self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
-        if self.recompute_granularity == "full":
-            assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support"
 
         virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1)
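
Below is a minimal standalone sketch, not taken from the patch itself, of how the condition introduced above lets no_recompute_layers take effect under pipeline parallelism with full recompute. The config fields mirror the ones referenced in the diff; LlamaConfigSketch, build_layer_flags, and should_recompute are hypothetical names used only for illustration.

from dataclasses import dataclass, field
from typing import List


@dataclass
class LlamaConfigSketch:
    # Hypothetical stand-in for the fields the diff reads from self.config.
    num_hidden_layers: int = 4
    recompute_granularity: str = "full"
    no_recompute_layers: List[int] = field(default_factory=list)


def build_layer_flags(config: LlamaConfigSketch) -> List[bool]:
    # Each decoder layer gets its own layerwise_recompute flag; layers listed
    # in no_recompute_layers are excluded from recomputation.
    return [i not in config.no_recompute_layers for i in range(config.num_hidden_layers)]


def should_recompute(enable_recompute: bool, layerwise_recompute: bool,
                     config: LlamaConfigSketch, has_gradient: bool) -> bool:
    # Mirrors the gating condition added by the patch.
    return (
        enable_recompute
        and layerwise_recompute
        and config.recompute_granularity == "full"
        and has_gradient
    )


if __name__ == "__main__":
    cfg = LlamaConfigSketch(no_recompute_layers=[1, 3])
    for idx, flag in enumerate(build_layer_flags(cfg)):
        print(idx, should_recompute(True, flag, cfg, has_gradient=True))
    # Layers 1 and 3 print False, i.e. they run forward without recomputation,
    # while the remaining layers are still recomputed under full granularity.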