Revert "Add FlaxWhisperForAudioClassification model" (huggingface#23154)
Revert "Add FlaxWhisperForAudioClassification model (huggingface#22883)"

This reverts commit c8f2c5c.
sgugger authored and novice03 committed Jun 23, 2023
1 parent e3f058f commit 556242d
Showing 7 changed files with 3 additions and 390 deletions.
6 changes: 0 additions & 6 deletions docs/source/en/model_doc/whisper.mdx
@@ -105,9 +105,3 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__

## FlaxWhisperForAudioClassification

[[autodoc]] FlaxWhisperForAudioClassification
- __call__

8 changes: 1 addition & 7 deletions src/transformers/__init__.py
@@ -3793,7 +3793,6 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]
)
_import_structure["models.xglm"].extend(
@@ -6930,12 +6929,7 @@
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
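The `_import_structure` edits above (and the matching change in `src/transformers/models/whisper/__init__.py` further down) feed transformers' lazy-import machinery: exported names are registered up front, but the defining submodule is only imported on first access. A minimal sketch of the idea, using a simplified stand-in rather than the real `_LazyModule`:

```python
import importlib

class LazyModule:
    # Toy stand-in for transformers' _LazyModule, simplified for illustration.
    def __init__(self, package_name, import_structure):
        self._package_name = package_name
        # Invert the registry: exported symbol -> submodule that defines it.
        self._symbol_to_module = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        # Import the defining submodule only when a symbol is first touched.
        submodule = self._symbol_to_module[symbol]
        module = importlib.import_module(f"{self._package_name}.{submodule}")
        return getattr(module, symbol)

# After the revert, the Whisper Flax entry no longer lists
# FlaxWhisperForAudioClassification:
_import_structure = {
    "models.whisper": [
        "FlaxWhisperForConditionalGeneration",
        "FlaxWhisperModel",
        "FlaxWhisperPreTrainedModel",
    ]
}
transformers_lazy = LazyModule("transformers", _import_structure)
```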
5 changes: 0 additions & 5 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -229,11 +229,6 @@
]
)

FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("whisper", "FlaxWhisperForAudioClassification"),
]
)

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
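The deleted `FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES` dict was the registry that would have let an auto class resolve the `"whisper"` model type to `FlaxWhisperForAudioClassification` by name. A rough sketch of the lookup pattern behind `_LazyAutoMapping` (a simplification; `resolve_model_class_name` is a hypothetical helper, not a transformers API):

```python
from collections import OrderedDict

# Simplified name-based registries in the style of modeling_flax_auto.py.
CONFIG_MAPPING_NAMES = OrderedDict([("whisper", "WhisperConfig")])
FLAX_MODEL_MAPPING_NAMES = OrderedDict([("whisper", "FlaxWhisperModel")])

def resolve_model_class_name(model_type: str) -> str:
    # Look up the class name registered for a model type; the real
    # _LazyAutoMapping also defers importing the class until it is needed.
    if model_type not in FLAX_MODEL_MAPPING_NAMES:
        raise KeyError(f"No Flax model registered for model type {model_type!r}")
    return FLAX_MODEL_MAPPING_NAMES[model_type]

print(resolve_model_class_name("whisper"))  # FlaxWhisperModel
```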
2 changes: 0 additions & 2 deletions src/transformers/models/whisper/__init__.py
@@ -75,7 +75,6 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]


@@ -127,7 +126,6 @@
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
160 changes: 0 additions & 160 deletions src/transformers/models/whisper/modeling_flax_whisper.py
@@ -36,7 +36,6 @@
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
@@ -1507,162 +1506,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxWhisperForAudioClassificationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum:
self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(
self,
input_features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states: bool = True,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = jnp.stack(encoder_outputs, axis=1)
norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = encoder_outputs[0]

hidden_states = self.projector(hidden_states)
pooled_output = jnp.mean(hidden_states, axis=1)

logits = self.classifier(pooled_output)

if not return_dict:
return (logits,) + encoder_outputs[1:]

return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForAudioClassificationModule
dtype: jnp.dtype = jnp.float32

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

random_params = self.module.init(
rngs,
input_features=input_features,
)["params"]

if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng

return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
)


FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
Returns:
Audio classification example:
```python
>>> import jax.numpy as jnp
>>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = FlaxWhisperForAudioClassification.from_pretrained(
... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
... )
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))
>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
... )
>>> input_features = inputs.input_features
>>> logits = model(input_features).logits
>>> predicted_class_ids = jnp.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'af_za'
```
"""

overwrite_call_docstring(
FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)
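For reference, the classification head deleted above pools the encoder output in two steps: an optional softmax-weighted sum across all encoder layer outputs (when `use_weighted_layer_sum` is set), then a mean over the time axis before the projector and classifier. A standalone sketch of that pooling, with small hypothetical shapes in place of real encoder states:

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes for illustration only.
batch, num_layers_plus_embed, seq_len, hidden = 2, 5, 10, 8

# Stand-in for jnp.stack(encoder_outputs, axis=1):
# shape (batch, num_layers + 1, seq_len, hidden).
hidden_states = jax.random.normal(
    jax.random.PRNGKey(0), (batch, num_layers_plus_embed, seq_len, hidden)
)

# Uniform initial weights, normalized with a softmax as in the module above.
layer_weights = jnp.repeat(1 / num_layers_plus_embed, num_layers_plus_embed)
norm_weights = jax.nn.softmax(layer_weights, axis=-1)

# Weighted sum over the layer axis, then mean-pool over time.
summed = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
pooled_output = jnp.mean(summed, axis=1)  # shape (batch, hidden)
```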
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_flax_objects.py
@@ -1182,13 +1182,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForAudioClassification(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]

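The dummy class removed here follows the usual transformers fallback pattern: when the flax backend is missing, a same-named placeholder is still importable, and any attempt to use it raises an informative error. A simplified sketch of the mechanism (an assumption-level illustration; the real `DummyObject` metaclass also guards class-level access such as `from_pretrained`):

```python
import importlib.util

def requires_backends(obj, backends):
    # Raise an ImportError naming each required backend that is not installed.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = type(obj).__name__
        raise ImportError(f"{name} requires the {', '.join(missing)} backend(s).")

class DummyObject(type):
    # Placeholder metaclass; the real one also intercepts classmethod access.
    pass

class FlaxWhisperForAudioClassification(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
```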