FX support for ConvNext, Wav2Vec2 and ResNet #19053

Merged
2 changes: 1 addition & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -960,7 +960,7 @@ def forward(self, hidden_states, mask_time_indices=None):
             # take argmax in non-differentiable way
             # compute hard codevector distribution (one hot)
             codevector_idx = hidden_states.argmax(dim=-1)
-            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                 -1, codevector_idx.view(-1, 1), 1.0
             )
             codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
@@ -1023,7 +1023,7 @@ def forward(self, hidden_states, mask_time_indices=None):
             # take argmax in non-differentiable way
             # compute hard codevector distribution (one hot)
             codevector_idx = hidden_states.argmax(dim=-1)
-            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                 -1, codevector_idx.view(-1, 1), 1.0
             )
             codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
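
Both hunks apply the same one-line fix, and it is what makes the quantizer traceable: under torch.fx symbolic tracing, `hidden_states.shape` is a `Proxy`, and unpacking it with `*` forces Python to iterate over that proxy, which the tracer rejects. Passing the shape object through as a single argument records cleanly. A minimal sketch with vanilla `torch.fx` (the library's `HFTracer` is an `fx.Tracer` subclass, so the same constraint applies):

```python
import torch
from torch import fx


class Quantizer(torch.nn.Module):
    def forward(self, hidden_states):
        # Traceable: the shape flows through as one symbolic value.
        return hidden_states.new_zeros(hidden_states.shape)


traced = fx.symbolic_trace(Quantizer())
print(traced(torch.randn(2, 4)).shape)  # torch.Size([2, 4])

# `hidden_states.new_zeros(*hidden_states.shape)` would instead iterate the
# Proxy at trace time and raise torch.fx.proxy.TraceError.
```
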
5 changes: 4 additions & 1 deletion src/transformers/utils/fx.py
@@ -104,6 +104,7 @@ def _generate_supported_model_class_names(
     "blenderbot-small",
     "bloom",
     "clip",
+    "convnext",
     "deberta",
     "deberta-v2",
     "distilbert",
@@ -125,6 +126,7 @@
     "opt",
     "pegasus",
     "plbart",
+    "resnet",
     "roberta",
     "speech_to_text",
     "speech_to_text_2",
@@ -133,6 +135,7 @@
     "trocr",
     "vit",
     "xglm",
+    "wav2vec2",
     # "xlnet",
 ]
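
With `convnext`, `resnet`, and `wav2vec2` registered here, `transformers.utils.fx.symbolic_trace` accepts those architectures. A quick sketch; the tiny config values are hypothetical, chosen only to keep the example fast:

```python
import torch
from transformers import ConvNextConfig, ConvNextModel
from transformers.utils.fx import symbolic_trace

# Deliberately small hypothetical config; the stock defaults also trace.
config = ConvNextConfig(hidden_sizes=[8, 16, 32, 64], depths=[1, 1, 1, 1])
model = ConvNextModel(config).eval()

traced = symbolic_trace(model, input_names=["pixel_values"])
out = traced(pixel_values=torch.randn(1, 3, 224, 224))
```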

@@ -743,7 +746,7 @@ def _generate_dummy_input(
             elif hasattr(model.config, "encoder"):
                 image_size = model.config.encoder.image_size
             else:
-                raise AttributeError('Could not find the "image_size" field in the model config')
+                image_size = (_generate_random_int(), _generate_random_int())

             # If no num_channels is in the config, use some arbitrary value.
             num_channels = getattr(model.config, "num_channels", 3)
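
The dummy-input generator previously raised when a vision config carried no `image_size`; it now falls back to a random spatial size, which is what lets configs lacking that field go through tracing. Roughly, with the caveat that `_generate_random_int` below is a stand-in whose exact signature in `src/transformers/utils/fx.py` may differ:

```python
import random

import torch


def _generate_random_int(low: int = 10, high: int = 20) -> int:
    # Hypothetical stand-in for the helper referenced in the diff above;
    # the real implementation may differ.
    return random.randint(low, high)


# Fallback path: no image_size anywhere in the config.
image_size = (_generate_random_int(), _generate_random_int())
num_channels = 3  # also the fallback when the config has no num_channels
pixel_values = torch.zeros(2, num_channels, *image_size)  # dummy batch of 2
```
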
1 change: 1 addition & 0 deletions tests/models/convnext/test_modeling_convnext.py
@@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
1 change: 1 addition & 0 deletions tests/models/resnet/test_modeling_resnet.py
@@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()

+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
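
For both vision models the only test change needed is the `fx_compatible` flag: `ModelTesterMixin` already ships the fx tests and skips them unless the flag is set. A sketch of the entry points it gates, with names assumed from the common suite in `tests/test_modeling_common.py`:

```python
# Sketch of how ModelTesterMixin consumes `fx_compatible`; bodies paraphrased.
class ModelTesterMixinSketch:
    fx_compatible = False  # each model test class overrides this

    def test_torch_fx(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict)

    def test_torch_fx_output_loss(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
```
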
109 changes: 109 additions & 0 deletions tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -15,6 +15,9 @@
 """ Testing suite for the PyTorch Wav2Vec2 model. """

 import math
+import os
+import pickle
+import tempfile
 import unittest

 import numpy as np
@@ -32,6 +35,7 @@
     slow,
     torch_device,
 )
+from transformers.utils import is_torch_fx_available

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
@@ -72,6 +76,10 @@
     from transformers import Wav2Vec2ProcessorWithLM


+if is_torch_fx_available():
+    from transformers.utils.fx import symbolic_trace
+
+
 class Wav2Vec2ModelTester:
     def __init__(
         self,
@@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True
     test_pruning = False
     test_headmasking = False

@@ -633,6 +642,106 @@ def test_model_from_pretrained(self):
         model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
         self.assertIsNotNone(model)

+    # Wav2Vec2 cannot be torchscripted because of group norm.
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
+            return
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
+        configs_no_init.return_dict = False
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            model.to(torch_device)
+            model.eval()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+            try:
+                input_names = [
+                    "attention_mask",
+                    "bbox",
+                    "input_features",
+                    "input_ids",
+                    "input_values",
+                    "pixel_values",
+                    "token_type_ids",
+                    "visual_feats",
+                    "visual_pos",
+                ]
+
+                labels = inputs.get("labels", None)
+                start_positions = inputs.get("start_positions", None)
+                end_positions = inputs.get("end_positions", None)
+                if labels is not None:
+                    input_names.append("labels")
+                if start_positions is not None:
+                    input_names.append("start_positions")
+                if end_positions is not None:
+                    input_names.append("end_positions")
+
+                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                input_names = list(filtered_inputs.keys())
+
+                model_output = model(**filtered_inputs)
+
+                if isinstance(model, Wav2Vec2ForSequenceClassification) and (
+                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
+                ):
+                    model.config.problem_type = "single_label_classification"
+
+                traced_model = symbolic_trace(model, input_names)
+                traced_output = traced_model(**filtered_inputs)
+
+            except Exception as e:
+                self.fail(f"Couldn't trace module: {e}")
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
+            # Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
+            # (Even with this call, there is still a leak of ~0.04MB.)
+            self.clear_torch_jit_class_registry()


 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
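
Outside the test harness, the new Wav2Vec2 support plus the pickle round trip can be exercised directly. A minimal sketch; the small config values below are hypothetical and chosen only to keep it quick:

```python
import pickle

import torch
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.utils.fx import symbolic_trace

# Hypothetical tiny config; real checkpoints such as
# facebook/wav2vec2-base-960h trace the same way, just slower.
config = Wav2Vec2Config(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = Wav2Vec2Model(config).eval()

traced = symbolic_trace(model, input_names=["input_values"])
out = traced(input_values=torch.randn(1, 16000))  # one second at 16 kHz

# Mirrors the serialization check in the new test: a traced GraphModule
# survives a pickle round trip and produces the same outputs.
restored = pickle.loads(pickle.dumps(traced))
```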