Skip to content

Commit 677b4e5

Browse files
Add backend guards in feature extractor
1 parent 93486bf commit 677b4e5

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/transformers/models/granite_speech/feature_extraction_granite_speech.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
2424
from ...tokenization_utils_base import AudioInput
2525
from ...utils import is_torch_available, is_torchaudio_available, logging
26+
from ...utils.import_utils import requires_backends
2627

2728

2829
logger = logging.get_logger(__name__)
@@ -66,6 +67,8 @@ def __call__(
6667
audios: AudioInput,
6768
device: Optional[str] = "cpu",
6869
) -> BatchFeature:
70+
requires_backends(self, ["torchaudio"])
71+
6972
speech_inputs = {}
7073
batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios)
7174
speech_inputs["input_features"] = self._extract_mel_spectrograms(
@@ -95,15 +98,19 @@ def _ensure_melspec_transform_is_initialized(self):
9598
We do this for now because some logging explodes when the mel spectrogram
9699
transform is not JSON serializable.
97100
"""
101+
requires_backends(self, ["torchaudio"])
102+
98103
if self.melspec is None:
99104
# TODO (@alex-jw-brooks / @eustlb) move this to common batch
100105
# feature extraction in audio utils once they are written!
101106
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
102107

103-
def _extract_mel_spectrograms(self, audio: torch.Tensor, device="cpu"):
108+
def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
104109
"""
105110
Compute the Mel features to be passed to the conformer encoder.
106111
"""
112+
requires_backends(self, ["torchaudio"])
113+
107114
# Initialize the mel spectrogram if it isn't already and
108115
# move the melspec / audio to the computation device.
109116
self._ensure_melspec_transform_is_initialized()
@@ -156,14 +163,16 @@ def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int]
156163

157164
return projector_lengths
158165

159-
def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence[torch.Tensor, Sequence[int]]:
166+
def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence["torch.Tensor", Sequence[int]]:
160167
"""
161168
Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking.
162169
163170
Args:
164171
audios (`AudioInput`):
165172
Audio sequence, numpy array, or torch tensor.
166173
"""
174+
requires_backends(self, ["torch"])
175+
167176
# Coerce to PyTorch tensors if we have numpy arrays, since
168177
# currently we have a dependency on torch/torchaudio anyway
169178
if isinstance(audios, np.ndarray):

0 commit comments

Comments
 (0)