Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Add FlaxWhisperForAudioClassification model" #23154

Merged
merged 1 commit into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions docs/source/en/model_doc/whisper.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,3 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__

## FlaxWhisperForAudioClassification

[[autodoc]] FlaxWhisperForAudioClassification
- __call__

8 changes: 1 addition & 7 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3779,7 +3779,6 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]
)
_import_structure["models.xglm"].extend(
Expand Down Expand Up @@ -6904,12 +6903,7 @@
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
5 changes: 0 additions & 5 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@
]
)

FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("whisper", "FlaxWhisperForAudioClassification"),
]
)

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]


Expand Down Expand Up @@ -127,7 +126,6 @@
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
Expand Down
160 changes: 0 additions & 160 deletions src/transformers/models/whisper/modeling_flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
Expand Down Expand Up @@ -1507,162 +1506,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxWhisperForAudioClassificationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum:
self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(
self,
input_features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states: bool = True,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = jnp.stack(encoder_outputs, axis=1)
norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = encoder_outputs[0]

hidden_states = self.projector(hidden_states)
pooled_output = jnp.mean(hidden_states, axis=1)

logits = self.classifier(pooled_output)

if not return_dict:
return (logits,) + encoder_outputs[1:]

return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForAudioClassificationModule
dtype: jnp.dtype = jnp.float32

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

random_params = self.module.init(
rngs,
input_features=input_features,
)["params"]

if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng

return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
)


FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
Returns:

Transcription example:

```python
>>> import jax.numpy as jnp
>>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
>>> from datasets import load_dataset

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = FlaxWhisperForAudioClassification.from_pretrained(
... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
... )
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)

>>> sample = next(iter(ds))

>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
... )
>>> input_features = inputs.input_features

>>> logits = model(input_features).logits

>>> predicted_class_ids = jnp.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'af_za'
```
"""

overwrite_call_docstring(
FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,13 +1182,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForAudioClassification(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]

Expand Down
Loading