Skip to content

Commit

Permalink
Enable fx tracing for Mistral (#30209)
Browse files (browse the repository at this point in the history)
* tracing for mistral

* typo

* fix copies
  • Loading branch information
zucchini-nlp authored and ydshieh committed Apr 23, 2024
1 parent 8b13656 commit 8083fca
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 6 deletions.
3 changes: 0 additions & 3 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,9 +868,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

if top_x.shape[0] == 0:
continue

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
Expand Down
3 changes: 0 additions & 3 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

if top_x.shape[0] == 0:
continue

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,16 @@ def _generate_supported_model_class_names(
"marian",
"mbart",
"megatron-bert",
"mistral",
"mixtral",
"mobilebert",
"mt5",
"nezha",
"opt",
"pegasus",
"plbart",
"qwen2",
"qwen2_moe",
"resnet",
"roberta",
"segformer",
Expand Down Expand Up @@ -758,6 +762,7 @@ class HFTracer(Tracer):
"tensor",
"clamp",
"finfo",
"tril",
]
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

Expand Down
1 change: 1 addition & 0 deletions tests/models/mistral/test_modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
)
test_headmasking = False
test_pruning = False
fx_compatible = True

# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
Expand Down
1 change: 1 addition & 0 deletions tests/models/mixtral/test_modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
)
test_headmasking = False
test_pruning = False
fx_compatible = True

# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
Expand Down
1 change: 1 addition & 0 deletions tests/models/qwen2/test_modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = True

# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
Expand Down
1 change: 1 addition & 0 deletions tests/models/qwen2_moe/test_modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
)
test_headmasking = False
test_pruning = False
fx_compatible = True

# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
Expand Down

0 comments on commit 8083fca

Please sign in to comment.