Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FA2] Add flash attention for opt #26414

Merged
merged 16 commits into from
Nov 23, 2023
243 changes: 235 additions & 8 deletions src/transformers/models/opt/modeling_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
Expand All @@ -32,12 +33,18 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_opt import OPTConfig


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
Expand All @@ -63,6 +70,19 @@
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
Expand Down Expand Up @@ -160,6 +180,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

Expand Down Expand Up @@ -273,17 +294,212 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


class OptFlashAttention2(OPTAttention):
"""
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of flash
attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

_, query_length, _, _ = query_states.shape

attn_dropout = self.dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, query_length, dropout=attn_dropout
)

attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
attn_output = self.out_proj(attn_weights_reshaped)

if not output_attentions:
attn_weights_reshaped = None

return attn_output, attn_weights_reshaped, past_key_value

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if padding_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)

if not getattr(config, "_flash_attn_2_enabled", False):
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)
else:
self.self_attn = OptFlashAttention2(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)

self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
Expand All @@ -303,6 +519,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand Down Expand Up @@ -333,6 +550,7 @@ def forward(
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
padding_mask=padding_mask,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -399,6 +617,7 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True

def _init_weights(self, module):
std = self.config.init_std
Expand Down Expand Up @@ -642,11 +861,18 @@ def forward(
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
padding_mask = None
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None

causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
Expand Down Expand Up @@ -695,7 +921,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return module(*inputs, output_attentions, None, padding_mask=padding_mask)

return custom_forward

Expand All @@ -714,6 +940,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)

hidden_states = layer_outputs[0]
Expand Down
1 change: 0 additions & 1 deletion tests/models/opt/test_modeling_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
else {}
)
is_encoder_decoder = False
fx_compatible = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should not be removed

Copy link
Contributor Author

@susnato susnato Sep 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to get error for the fx tests - test_torch_fx and test_torch_fx_output_loss.
The error is happening for this line.

The full error for test_torch_fx

tests/models/opt/test_modeling_opt.py F [100%]

========================================================= FAILURES =========================================================
________________________________________________ OPTModelTest.test_torch_fx ________________________________________________

self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
config = OPTConfig {
"_remove_final_layer_norm": false,
"activation_function": "relu",
"attention_dropout": 0.1,
"bos_t...d": 1,
"transformers_version": "4.34.0.dev0",
"use_cache": true,
"vocab_size": 99,
"word_embed_proj_dim": 16
}

inputs_dict = {'attention_mask': tensor([[True, True, True, True, True, True, True],
[True, True, True, True, True, True, Tr...91, 54, 98, 57, 55, 2],
[74, 79, 56, 51, 93, 26, 2],
[62, 18, 55, 3, 73, 74, 2]], device='cuda:0')}
output_loss = False

def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
    if not is_torch_fx_available() or not self.fx_compatible:
        return

    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
    configs_no_init.return_dict = False

    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        model.to(torch_device)
        model.eval()
        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

        try:
            if model.config.is_encoder_decoder:
                model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                labels = inputs.get("labels", None)
                input_names = [
                    "attention_mask",
                    "decoder_attention_mask",
                    "decoder_input_ids",
                    "input_features",
                    "input_ids",
                    "input_values",
                ]
                if labels is not None:
                    input_names.append("labels")

                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                input_names = list(filtered_inputs.keys())

                model_output = model(**filtered_inputs)

                traced_model = symbolic_trace(model, input_names)
                traced_output = traced_model(**filtered_inputs)
            else:
                input_names = [
                    "attention_mask",
                    "bbox",
                    "input_features",
                    "input_ids",
                    "input_values",
                    "pixel_values",
                    "token_type_ids",
                    "visual_feats",
                    "visual_pos",
                ]

                labels = inputs.get("labels", None)
                start_positions = inputs.get("start_positions", None)
                end_positions = inputs.get("end_positions", None)
                if labels is not None:
                    input_names.append("labels")
                if start_positions is not None:
                    input_names.append("start_positions")
                if end_positions is not None:
                    input_names.append("end_positions")

                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                input_names = list(filtered_inputs.keys())

                if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
                ):
                    model.config.problem_type = "single_label_classification"
              traced_model = symbolic_trace(model, input_names)

tests/test_modeling_common.py:877:


model = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_names = ['input_ids', 'attention_mask'], disable_check = False, tracer_cls = <class 'transformers.utils.fx.HFTracer'>

def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
    tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    tracer = tracer_cls()
  traced_graph = tracer.trace(model, concrete_args=concrete_args)

src/transformers/utils/fx.py:1250:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
dummy_inputs = None, complete_concrete_args_with_inputs_not_in_dummy_inputs = True

def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    dummy_inputs: Optional[Dict[str, Any]] = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
    """
    Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
    `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
    the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
    `torch.nn.Module` instance to use as the root and add embedded constants to.

    Args:
        root (`torch.nn.Module` or  `Callable`):
            Either a `torch.nn.Module`` or a function to be traced through. If root is not a
            [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
        concrete_args (`Dict[str, Any], *optional*):
            Concrete arguments that should not be treated as Proxies
        dummy_inputs (`Dict[str, Any]`, *optional*):
            The dummy inputs needed to handle data-dependent control-flow if `root` is not a
            [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
            [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
        complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
            If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
            `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

    Returns:
        `torch.fx.Graph`:
            A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

    """
    sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

    if concrete_args is None:
        concrete_args = {}

    if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
        for param in sig.parameters.values():
            if param.name in dummy_inputs:
                continue
            if param.default is inspect.Parameter.empty:
                raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
        concrete_args.update(
            {
                p.name: p.default
                for p in sig.parameters.values()
                if (p.name not in dummy_inputs and p.name not in concrete_args)
            }
        )

    input_names = sig.parameters.keys() - concrete_args.keys()

    # Creating a random input shape to generate dummy inputs.
    batch_size = _generate_random_int()
    sequence_length = _generate_random_int()
    shape = [batch_size, sequence_length]

    if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
        num_choices = _generate_random_int(low=2, high=5)
        shape.insert(1, num_choices)

    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
    for input_name in input_names:
        if input_name in inputs:
            continue
        # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
        # be able to use HFTracer._generate_dummy_input.
        if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
            ("_deserialize_graph_module", "_CodeOnlyModule")
        ):
            inputs.update(self._generate_dummy_input(root, input_name, shape))
        else:
            raise RuntimeError(
                f"Could not generate input named {input_name} for because root is not a"
                " transformers.PreTrainedModel."
            )

    concrete_metas = {
        input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
        for input_name, input_ in inputs.items()
    }
    for param in sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
            concrete_metas[f"**{param.name}"] = {}
    self.meta_args = concrete_metas
    self.patched_torch_methods = {
        target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
    }
    self.orig_fns = set()

    for name, (wrapper, orig) in self.patched_torch_methods.items():
        setattr(torch, name, wrapper)
        self.orig_fns.add(orig)

    try:
      self.graph = super().trace(root, concrete_args=concrete_args)

src/transformers/utils/fx.py:1088:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}

@compatibility(is_backward_compatible=True)
def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
    """
    Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
    can either be an ``nn.Module`` instance or a Python callable.

    Note that after this call, ``self.root`` may be different from the ``root`` passed
    in here. For example, when a free function is passed to ``trace()``, we will
    create an ``nn.Module`` instance to use as the root and add embedded constants
    to.


    Args:

        root (Union[Module, Callable]): Either a ``Module`` or a function to be
            traced through. Backwards-compatibility for this parameter is
            guaranteed.
        concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
            not be treated as Proxies. This parameter is experimental and
            its backwards-compatibility is *NOT* guaranteed.

    Returns:

        A ``Graph`` representing the semantics of the passed-in ``root``.
    """
    global _is_fx_tracing_flag
    old_is_fx_tracing_flag = _is_fx_tracing_flag
    _is_fx_tracing_flag = True
    try:
        if isinstance(root, torch.nn.Module):
            self.root = root

            assert hasattr(
                type(root), self.traced_func_name
            ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

            fn = getattr(type(root), self.traced_func_name)
            self.submodule_paths = {mod: name for name, mod in root.named_modules()}
        else:
            self.root = torch.nn.Module()
            fn = root

        tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
        self.graph = Graph(tracer_cls=tracer_cls)

        # When we encounter a Tensor value that's not a parameter, we look if it
        # is some other attribute on the model. Construct a dict mapping Tensor
        # values to the qualified name here for efficiency. This is used downstream
        # in create_arg
        self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}

        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
            for k, v in m.__dict__.items():
                if isinstance(v, (torch.Tensor, ScriptObject)):
                    self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
            for k, v in m.named_children():
                collect_tensor_attrs(v, prefix_atoms + [k])

        collect_tensor_attrs(self.root, [])

        assert isinstance(fn, FunctionType)

        fn_globals = fn.__globals__  # run before it gets patched
        fn, args = self.create_args_for_root(
            fn, isinstance(root, torch.nn.Module), concrete_args
        )

        parameter_proxy_cache: Dict[
            str, Proxy
        ] = {}  # Reduce number of get_attr calls

        # Method dispatch on parameters is not recorded unless it's directly used.
        # Thus, we need to insert a proxy when __getattr__ requests a parameter.
        @functools.wraps(_orig_module_getattr)
        def module_getattr_wrapper(mod, attr):
            attr_val = _orig_module_getattr(mod, attr)
            return self.getattr(attr, attr_val, parameter_proxy_cache)

        @functools.wraps(_orig_module_call)
        def module_call_wrapper(mod, *args, **kwargs):
            def forward(*args, **kwargs):
                return _orig_module_call(mod, *args, **kwargs)

            _autowrap_check(
                patcher,
                getattr(getattr(mod, "forward", mod), "__globals__", {}),
                self._autowrap_function_ids,
            )
            return self.call_module(mod, forward, args, kwargs)

        with _Patcher() as patcher:
            # allow duplicate patches to support the case of nested calls
            patcher.patch_method(
                torch.nn.Module,
                "__getattr__",
                module_getattr_wrapper,
                deduplicate=False,
            )
            patcher.patch_method(
                torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
            )
            _patch_wrapped_functions(patcher)
            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
            for module in self._autowrap_search:
                _autowrap_check(
                    patcher, module.__dict__, self._autowrap_function_ids
                )
            self.create_node(
                "output",
                "output",
              (self.create_arg(fn(*args)),),
                {},
                type_expr=fn.__annotations__.get("return", None),
            )

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:739:


self = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_ids = Proxy(input_ids), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = None, use_cache = True, output_attentions = False, output_hidden_states = False, return_dict = False

@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPast,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
  decoder_outputs = self.decoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

src/transformers/models/opt/modeling_opt.py:1023:


mod = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>

@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
    def forward(*args, **kwargs):
        return _orig_module_call(mod, *args, **kwargs)

    _autowrap_check(
        patcher,
        getattr(getattr(mod, "forward", mod), "__globals__", {}),
        self._autowrap_function_ids,
    )
  return self.call_module(mod, forward, args, kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:717:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

def call_module(self, m, forward, args, kwargs):
    self.orig_forward = forward
  return super().call_module(m, forward, args, kwargs)

src/transformers/utils/fx.py:987:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

@compatibility(is_backward_compatible=True)
def call_module(
    self,
    m: torch.nn.Module,
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Method that specifies the behavior of this ``Tracer`` when it encounters
    a call to an ``nn.Module`` instance.

    By default, the behavior is to check if the called module is a leaf module
    via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
    ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
    the operations in its ``forward`` function.

    This method can be overridden to--for example--create nested traced
    GraphModules, or any other behavior you would want while tracing across
    ``Module`` boundaries.

    Args:

        m (Module): The module for which a call is being emitted
        forward (Callable): The forward() method of the ``Module`` to be invoked
        args (Tuple): args of the module callsite
        kwargs (Dict): kwargs of the module callsite

    Return:

        The return value from the Module call. In the case that a ``call_module``
        node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
        value was returned from the ``Module`` invocation.
    """
    module_qualified_name = self.path_of_module(m)
    if not self.is_leaf_module(m, module_qualified_name):
      return forward(*args, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:434:


args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

def forward(*args, **kwargs):
  return _orig_module_call(mod, *args, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:710:


self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward_call = <bound method OPTDecoder.forward of OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions)...out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)>

def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
      return forward_call(*input, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194:


self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input_ids = Proxy(view), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = Proxy(decoder_embed_tokens), use_cache = True, output_attentions = False, output_hidden_states = False
return_dict = False

def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
            all `decoder_input_ids` of shape `(batch_size, sequence_length)`.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = input_shape
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
    # required mask seq length can be calculated via length of past
    mask_seq_length = past_key_values_length + seq_length

    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        padding_mask = None
    elif attention_mask.shape[1] != mask_seq_length:
        raise ValueError(
            f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
            f"{mask_seq_length} (sum of the lengths of current and past inputs)"
        )
    else:
      if 0 in attention_mask:

src/transformers/models/opt/modeling_opt.py:872:


self = Proxy(attention_mask), key = 0

def __contains__(self, key):
    if hasattr(self, "_metadata") and self._metadata is not None:
      return key in self._metadata

src/transformers/utils/fx.py:646:


self = tensor(..., device='meta', size=(14, 11), dtype=torch.int64), element = 0

def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Args:
        element (Tensor or scalar): element to be checked
            for presence in current tensor"
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__contains__, (self,), self, element)
    if isinstance(element, (torch.Tensor, Number)):
        # type hint doesn't understand the __contains__ result array
      return (element == self).any().item()  # type: ignore[union-attr]

E NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/_tensor.py:983: NotImplementedError

During handling of the above exception, another exception occurred:

self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>

def test_torch_fx(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
  self._create_and_check_torch_fx_tracing(config, inputs_dict)

tests/test_modeling_common.py:805:


tests/test_modeling_common.py:882: in _create_and_check_torch_fx_tracing
self.fail(f"Couldn't trace module: {e}")
E AssertionError: Couldn't trace module: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]

EDIT : The errors are fixed now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I see, in that line can you use instead torch.isin ? https://pytorch.org/docs/stable/generated/torch.isin.html - suggestion from @michaelbenayoun

test_pruning = False
test_missing_keys = False

Expand Down