Wrap _prepare_4d_causal_attention_mask as a leaf function (#27236)
michaelbenayoun authored Nov 2, 2023
1 parent: 8a31295 · commit: 4557a0d
Showing 1 changed file with 7 additions and 0 deletions.
src/transformers/models/llama/modeling_llama.py (7 additions, 0 deletions)
@@ -40,6 +40,7 @@
logging,
replace_return_docstrings,
)
+from ...utils.import_utils import is_torch_fx_available
from .configuration_llama import LlamaConfig


@@ -48,6 +49,12 @@
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and will simply appear as a node in the graph.
+if is_torch_fx_available():
+    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"
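
For context, torch.fx.wrap registers a function as a "leaf" for FX symbolic tracing: calls to it are recorded as a single call_function node in the graph rather than traced through, which matters when the function body contains data-dependent control flow that tracing cannot handle (as the mask-preparation helper does). Below is a minimal, self-contained sketch of the mechanism; scale_positive and Toy are hypothetical names for illustration, not part of transformers:

import torch
import torch.fx

def scale_positive(x):
    # Data-dependent control flow like this normally fails under symbolic
    # tracing, because the condition cannot be evaluated on a Proxy.
    if x.sum() > 0:
        return x * 2
    return x

# Register the function as a leaf. During tracing, calls to it become a
# single call_function node and the body above is never traced through.
# torch.fx.wrap returns the function, mirroring the pattern in this commit.
scale_positive = torch.fx.wrap(scale_positive)

class Toy(torch.nn.Module):
    def forward(self, x):
        return scale_positive(x) + 1

traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)            # scale_positive shows up as one node
print(traced(torch.ones(3)))   # tensor([3., 3., 3.])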
