From 2b437db58a1f46f071524af8d0768cee19fa407a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 15 Sep 2022 11:34:41 +0200 Subject: [PATCH 1/4] Support for ConvNext --- src/transformers/utils/fx.py | 1 + tests/models/convnext/test_modeling_convnext.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index c08f6766c9dfc4..ebbde3a706a9f8 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -104,6 +104,7 @@ def _generate_supported_model_class_names( "blenderbot-small", "bloom", "clip", + "convnext", "deberta", "deberta-v2", "distilbert", diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py index 46ef3ce71709cc..a8157765a487d6 100644 --- a/tests/models/convnext/test_modeling_convnext.py +++ b/tests/models/convnext/test_modeling_convnext.py @@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): else () ) + fx_ready = True test_pruning = False test_resize_embeddings = False test_head_masking = False From ce68710ee7c5c10e358b581e4fbdee6952dc1464 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 15 Sep 2022 14:08:29 +0200 Subject: [PATCH 2/4] Support for Wav2Vec2 --- .../models/wav2vec2/modeling_wav2vec2.py | 2 +- .../modeling_wav2vec2_conformer.py | 2 +- src/transformers/utils/fx.py | 1 + .../models/wav2vec2/test_modeling_wav2vec2.py | 109 ++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 9f678080039618..e1676399c14d08 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -960,7 +960,7 @@ def forward(self, hidden_states, mask_time_indices=None): # take argmax in non-differentiable way # comptute hard codevector distribution (one hot) codevector_idx = 
hidden_states.argmax(dim=-1) - codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_( -1, codevector_idx.view(-1, 1), 1.0 ) codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 5bee0d040c8ba4..8723c6338d2d83 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1023,7 +1023,7 @@ def forward(self, hidden_states, mask_time_indices=None): # take argmax in non-differentiable way # comptute hard codevector distribution (one hot) codevector_idx = hidden_states.argmax(dim=-1) - codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_( -1, codevector_idx.view(-1, 1), 1.0 ) codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index ebbde3a706a9f8..288798e9b02ed3 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -134,6 +134,7 @@ def _generate_supported_model_class_names( "trocr", "vit", "xglm", + "wav2vec2", # "xlnet", ] diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 21f77b19a553ca..040731472fe5bc 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -15,6 +15,9 @@ """ Testing suite for the PyTorch Wav2Vec2 model. 
""" import math +import os +import pickle +import tempfile import unittest import numpy as np @@ -32,6 +35,7 @@ slow, torch_device, ) +from transformers.utils import is_torch_fx_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -72,6 +76,10 @@ from transformers import Wav2Vec2ProcessorWithLM +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + + class Wav2Vec2ModelTester: def __init__( self, @@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_compatible = True test_pruning = False test_headmasking = False @@ -633,6 +642,106 @@ def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsNotNone(model) + # Wav2Vec2 cannot be torchscripted because of group norm. + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names 
= list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + if ( + isinstance(model, Wav2Vec2ForSequenceClassification) + and not hasattr(model.config, "problem_type") + or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except Exception as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. 
+ # (Even with this call, there is still a memory leak of ~0.04MB) + self.clear_torch_jit_class_registry() + @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): From c20ef406282b1b332e88da35529fe754ee4e634e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 15 Sep 2022 14:55:04 +0200 Subject: [PATCH 3/4] Support for Resnet --- src/transformers/utils/fx.py | 3 ++- tests/models/resnet/test_modeling_resnet.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 288798e9b02ed3..d3255baf847061 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -126,6 +126,7 @@ def _generate_supported_model_class_names( "opt", "pegasus", "plbart", + "resnet", "roberta", "speech_to_text", "speech_to_text_2", @@ -745,7 +746,7 @@ def _generate_dummy_input( elif hasattr(model.config, "encoder"): image_size = model.config.encoder.image_size else: - raise AttributeError('Could not find the "image_size" field in the model config') + image_size = (_generate_random_int(), _generate_random_int()) # If no num_channels is in the config, use some arbitrary value.
num_channels = getattr(model.config, "num_channels", 3) diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 83f08b68afb8be..557883e0b1ba9f 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else () + fx_compatible = True test_pruning = False test_resize_embeddings = False test_head_masking = False From c0cf5e65cca733893a493ca51512bd6260299377 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 15 Sep 2022 17:33:47 +0200 Subject: [PATCH 4/4] Fix small issue in test_modeling_convnext --- tests/models/convnext/test_modeling_convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py index a8157765a487d6..1225175a1b0641 100644 --- a/tests/models/convnext/test_modeling_convnext.py +++ b/tests/models/convnext/test_modeling_convnext.py @@ -137,7 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): else () ) - fx_ready = True + fx_compatible = True test_pruning = False test_resize_embeddings = False test_head_masking = False