From 07da90a54eb6e2e49b74f571ea55f2d60d60a49c Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 4 May 2023 13:46:54 -0400
Subject: [PATCH] Revert "Add FlaxWhisperForAudioClassification model (#22883)"

This reverts commit c8f2c5c56e942e8c45821d07555f2eab178b3f83.
---
 docs/source/en/model_doc/whisper.mdx          |   6 -
 src/transformers/__init__.py                  |   8 +-
 .../models/auto/modeling_flax_auto.py         |   5 -
 src/transformers/models/whisper/__init__.py   |   2 -
 .../models/whisper/modeling_flax_whisper.py   | 160 --------------
 src/transformers/utils/dummy_flax_objects.py  |   7 -
 .../whisper/test_modeling_flax_whisper.py     | 205 +-----------------
 7 files changed, 3 insertions(+), 390 deletions(-)

diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx
index 52a8b5953c6..22b08e4e61b 100644
--- a/docs/source/en/model_doc/whisper.mdx
+++ b/docs/source/en/model_doc/whisper.mdx
@@ -105,9 +105,3 @@ The original code can be found [here](https://github.com/openai/whisper).
 
 [[autodoc]] FlaxWhisperForConditionalGeneration
     - __call__
-
-## FlaxWhisperForAudioClassification
-
-[[autodoc]] FlaxWhisperForAudioClassification
-    - __call__
-
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b0766b0946c..7bf322ca8e1 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3779,7 +3779,6 @@
             "FlaxWhisperForConditionalGeneration",
             "FlaxWhisperModel",
             "FlaxWhisperPreTrainedModel",
-            "FlaxWhisperForAudioClassification",
         ]
     )
     _import_structure["models.xglm"].extend(
@@ -6904,12 +6903,7 @@
             FlaxWav2Vec2Model,
             FlaxWav2Vec2PreTrainedModel,
         )
-        from .models.whisper import (
-            FlaxWhisperForAudioClassification,
-            FlaxWhisperForConditionalGeneration,
-            FlaxWhisperModel,
-            FlaxWhisperPreTrainedModel,
-        )
+        from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
         from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
         from .models.xlm_roberta import (
             FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py
index e3b8d9cf5b5..755d1f07a34 100644
--- a/src/transformers/models/auto/modeling_flax_auto.py
+++ b/src/transformers/models/auto/modeling_flax_auto.py
@@ -229,11 +229,6 @@
     ]
 )
 
-FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
-    [
-        ("whisper", "FlaxWhisperForAudioClassification"),
-    ]
-)
 
 FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
 FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py
index cd962478e34..3b6015a56f6 100644
--- a/src/transformers/models/whisper/__init__.py
+++ b/src/transformers/models/whisper/__init__.py
@@ -75,7 +75,6 @@
         "FlaxWhisperForConditionalGeneration",
         "FlaxWhisperModel",
         "FlaxWhisperPreTrainedModel",
-        "FlaxWhisperForAudioClassification",
     ]
 
 
@@ -127,7 +126,6 @@
         pass
     else:
         from .modeling_flax_whisper import (
-            FlaxWhisperForAudioClassification,
             FlaxWhisperForConditionalGeneration,
             FlaxWhisperModel,
             FlaxWhisperPreTrainedModel,
diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py
index e36131680d6..b8d6f07242d 100644
--- a/src/transformers/models/whisper/modeling_flax_whisper.py
+++ b/src/transformers/models/whisper/modeling_flax_whisper.py
@@ -36,7 +36,6 @@
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
     FlaxSeq2SeqModelOutput,
-    FlaxSequenceClassifierOutput,
 )
 from ...modeling_flax_utils import (
     ACT2FN,
@@ -1507,162 +1506,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
 append_replace_return_docstrings(
     FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
 )
-
-
-class FlaxWhisperForAudioClassificationModule(nn.Module):
-    config: WhisperConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self) -> None:
-        self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
-        self.config.is_encoder_decoder = False
-        num_layers = self.config.num_hidden_layers + 1
-        if self.config.use_weighted_layer_sum:
-            self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
-        self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
-        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
-
-    def __call__(
-        self,
-        input_features,
-        encoder_outputs=None,
-        output_attentions=None,
-        output_hidden_states: bool = True,
-        return_dict: bool = True,
-    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(
-                input_features,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-        if self.config.use_weighted_layer_sum:
-            hidden_states = jnp.stack(encoder_outputs, axis=1)
-            norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
-            hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
-        else:
-            hidden_states = encoder_outputs[0]
-
-        hidden_states = self.projector(hidden_states)
-        pooled_output = jnp.mean(hidden_states, axis=1)
-
-        logits = self.classifier(pooled_output)
-
-        if not return_dict:
-            return (logits,) + encoder_outputs[1:]
-
-        return FlaxSequenceClassifierOutput(
-            logits=logits,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-
-@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
-class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
-    module_class = FlaxWhisperForAudioClassificationModule
-    dtype: jnp.dtype = jnp.float32
-
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
-        # init input tensors
-        input_features = jnp.zeros(input_shape, dtype="f4")
-        input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        random_params = self.module.init(
-            rngs,
-            input_features=input_features,
-        )["params"]
-
-        if params is not None:
-            random_params = flatten_dict(unfreeze(random_params))
-            params = flatten_dict(unfreeze(params))
-            for missing_key in self._missing_keys:
-                params[missing_key] = random_params[missing_key]
-            self._missing_keys = set()
-            return freeze(unflatten_dict(params))
-        else:
-            return random_params
-
-    @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    def __call__(
-        self,
-        input_features: jnp.ndarray,
-        attention_mask: Optional[jnp.ndarray] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        train: bool = False,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        **kwargs,
-    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            input_features=jnp.array(input_features, dtype="f4"),
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            rngs=rngs,
-        )
-
-
-FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
-    Returns:
-
-    Transcription example:
-
-    ```python
-    >>> import jax.numpy as jnp
-    >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
-    >>> from datasets import load_dataset
-
-    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
-    >>> model = FlaxWhisperForAudioClassification.from_pretrained(
-    ...     "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
-    ... )
-    >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
-
-    >>> sample = next(iter(ds))
-
-    >>> inputs = feature_extractor(
-    ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
-    ... )
-    >>> input_features = inputs.input_features
-
-    >>> logits = model(input_features).logits
-
-    >>> predicted_class_ids = jnp.argmax(logits).item()
-    >>> predicted_label = model.config.id2label[predicted_class_ids]
-    >>> predicted_label
-    'af_za'
-    ```
-"""
-
-overwrite_call_docstring(
-    FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
-)
-append_replace_return_docstrings(
-    FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
-)
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index ce571bc9f8d..eeec3277492 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -1182,13 +1182,6 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxWhisperForAudioClassification(metaclass=DummyObject):
-    _backends = ["flax"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["flax"])
-
-
 class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
     _backends = ["flax"]
 
diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py
index 79a2c51039a..3f1e201d72d 100644
--- a/tests/models/whisper/test_modeling_flax_whisper.py
+++ b/tests/models/whisper/test_modeling_flax_whisper.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
 import functools
 import inspect
 import tempfile
@@ -39,7 +41,6 @@
 
 from transformers import (
     FLAX_MODEL_MAPPING,
-    FlaxWhisperForAudioClassification,
     FlaxWhisperForConditionalGeneration,
     FlaxWhisperModel,
     WhisperFeatureExtractor,
@@ -703,205 +704,3 @@ def test_tiny_timestamp_generation(self):
 
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
-
-
-class FlaxWhisperEncoderModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=60,
-        is_training=True,
-        use_labels=True,
-        hidden_size=16,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        input_channels=1,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=20,
-        max_source_positions=30,
-        num_mel_bins=80,
-        num_conv_layers=1,
-        suppress_tokens=None,
-        begin_suppress_tokens=None,
-        classifier_proj_size=4,
-        num_labels=2,
-        is_encoder_decoder=False,
-        is_decoder=False,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_labels = use_labels
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.input_channels = input_channels
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.num_mel_bins = num_mel_bins
-        self.max_position_embeddings = max_position_embeddings
-        self.max_source_positions = max_source_positions
-        self.num_conv_layers = num_conv_layers
-        self.suppress_tokens = suppress_tokens
-        self.begin_suppress_tokens = begin_suppress_tokens
-        self.classifier_proj_size = classifier_proj_size
-        self.num_labels = num_labels
-        self.is_encoder_decoder = is_encoder_decoder
-        self.is_decoder = is_decoder
-
-    def get_config(self):
-        return WhisperConfig(
-            d_model=self.hidden_size,
-            encoder_layers=self.num_hidden_layers,
-            decoder_layers=self.num_hidden_layers,
-            encoder_attention_heads=self.num_attention_heads,
-            decoder_attention_heads=self.num_attention_heads,
-            input_channels=self.input_channels,
-            dropout=self.hidden_dropout_prob,
-            attention_dropout=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            max_source_positions=self.max_source_positions,
-            decoder_ffn_dim=self.hidden_size,
-            encoder_ffn_dim=self.hidden_size,
-            suppress_tokens=self.suppress_tokens,
-            begin_suppress_tokens=self.begin_suppress_tokens,
-            classifier_proj_size=self.classifier_proj_size,
-            num_labels=self.num_labels,
-            is_encoder_decoder=self.is_encoder_decoder,
-            is_decoder=self.is_decoder,
-        )
-
-    def prepare_whisper_encoder_inputs_dict(
-        self,
-        input_features,
-    ):
-        return {
-            "input_features": input_features,
-        }
-
-    def prepare_config_and_inputs(self):
-        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
-
-        config = self.get_config()
-        inputs_dict = self.prepare_whisper_encoder_inputs_dict(
-            input_features=input_features,
-        )
-        return config, inputs_dict
-
-    def prepare_config_and_inputs_for_common(self):
-        config, inputs_dict = self.prepare_config_and_inputs()
-        return config, inputs_dict
-
-    def get_subsampled_output_lengths(self, input_lengths):
-        """
-        Computes the output length of the convolutional layers
-        """
-
-        for i in range(self.num_conv_layers):
-            input_lengths = (input_lengths - 1) // 2 + 1
-
-        return input_lengths
-
-    @property
-    def encoder_seq_length(self):
-        return self.get_subsampled_output_lengths(self.seq_length)
-
-
-@require_flax
-class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
-    all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else ()
-    is_encoder_decoder = False
-    fx_compatible = False
-    test_pruning = False
-    test_missing_keys = False
-
-    input_name = "input_features"
-
-    def setUp(self):
-        self.model_tester = FlaxWhisperEncoderModelTester(self)
-        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
-
-        self.all_model_classes = (
-            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
-        )
-        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # overwrite because of `input_features`
-    def test_jit_compilation(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in self.all_model_classes:
-            with self.subTest(model_class.__name__):
-                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-                model = model_class(config)
-
-                @jax.jit
-                def model_jitted(input_features, **kwargs):
-                    return model(input_features=input_features, **kwargs)
-
-                with self.subTest("JIT Enabled"):
-                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
-
-                with self.subTest("JIT Disabled"):
-                    with jax.disable_jit():
-                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()
-
-                self.assertEqual(len(outputs), len(jitted_outputs))
-                for jitted_output, output in zip(jitted_outputs, outputs):
-                    self.assertEqual(jitted_output.shape, output.shape)
-
-    # overwrite because of `input_features`
-    def test_forward_signature(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in self.all_model_classes:
-            model = model_class(config)
-            signature = inspect.signature(model.__call__)
-            # signature.parameters is an OrderedDict => so arg_names order is deterministic
-            arg_names = [*signature.parameters.keys()]
-
-            expected_arg_names = ["input_features", "attention_mask", "output_attentions"]
-            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
-
-    def test_inputs_embeds(self):
-        pass
-
-    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
-    def test_model_common_attributes(self):
-        pass
-
-    # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
-    def test_resize_tokens_embeddings(self):
-        pass
-
-    # WhisperEncoder does not have any base model
-    def test_save_load_to_base(self):
-        pass
-
-    # WhisperEncoder does not have any base model
-    def test_save_load_from_base(self):
-        pass
-
-    # WhisperEncoder does not have any base model
-    @is_pt_flax_cross_test
-    def test_save_load_from_base_pt(self):
-        pass
-
-    # WhisperEncoder does not have any base model
-    @is_pt_flax_cross_test
-    def test_save_load_to_base_pt(self):
-        pass
-
-    # WhisperEncoder does not have any base model
-    @is_pt_flax_cross_test
-    def test_save_load_bf16_to_base_pt(self):
-        pass