[Llama] Support pp + no_recompute_layer. (PaddlePaddle#9373)
Co-authored-by: 周天宇 <tianyu.zhou@iluvatar.com>
tianyuzhou668 and 周天宇 authored Nov 6, 2024
1 parent 3971fc7 commit 2f0b407
Showing 1 changed file with 6 additions and 3 deletions.
paddlenlp/transformers/llama/modeling_pp.py (6 additions & 3 deletions)
@@ -242,7 +242,12 @@ def forward(self, args):
             attn_mask_startend_row_indices = None
 
         has_gradient = not hidden_states.stop_gradient
-        if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+        if (
+            self.enable_recompute
+            and self.layerwise_recompute
+            and self.config.recompute_granularity == "full"
+            and has_gradient
+        ):
             if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,

@@ -340,8 +345,6 @@ def __init__(self, config):
         self.recompute_granularity = self.config.recompute_granularity
         self.pp_recompute_interval = self.config.pp_recompute_interval
         self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
-        if self.recompute_granularity == "full":
-            assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support"
 
         virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1)
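For context, the condition added in the first hunk gates full recompute on a per-layer flag, and the assertion removed in the second hunk is what previously rejected a non-empty no_recompute_layers under pipeline parallelism with full recompute. A minimal, self-contained sketch of the resulting per-layer decision (illustrative only; the helper function and the small config stand-in below are assumptions, not PaddleNLP code):

from dataclasses import dataclass, field
from typing import List


@dataclass
class RecomputeConfig:
    # Stand-in for the LlamaConfig fields referenced in the diff.
    recompute_granularity: str = "full"
    no_recompute_layers: List[int] = field(default_factory=list)


def should_recompute_layer(layer_idx: int, config: RecomputeConfig,
                           enable_recompute: bool, has_gradient: bool) -> bool:
    # layerwise_recompute mirrors the per-layer flag in the new condition:
    # a layer listed in no_recompute_layers opts out of full recompute.
    layerwise_recompute = layer_idx not in config.no_recompute_layers
    return (
        enable_recompute
        and layerwise_recompute
        and config.recompute_granularity == "full"
        and has_gradient
    )


cfg = RecomputeConfig(no_recompute_layers=[0, 1])
print(should_recompute_layer(0, cfg, enable_recompute=True, has_gradient=True))  # False: layer 0 is exempt
print(should_recompute_layer(2, cfg, enable_recompute=True, has_gradient=True))  # True: layer 2 is recomputed

With the assertion gone, each pipeline decoder layer consults its own layerwise_recompute flag instead of failing at construction time when no_recompute_layers is non-empty.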
