|
23 | 23 | from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin |
24 | 24 | from ...tokenization_utils_base import AudioInput |
25 | 25 | from ...utils import is_torch_available, is_torchaudio_available, logging |
| 26 | +from ...utils.import_utils import requires_backends |
26 | 27 |
|
27 | 28 |
|
28 | 29 | logger = logging.get_logger(__name__) |
@@ -66,6 +67,8 @@ def __call__( |
66 | 67 | audios: AudioInput, |
67 | 68 | device: Optional[str] = "cpu", |
68 | 69 | ) -> BatchFeature: |
| 70 | + requires_backends(self, ["torchaudio"]) |
| 71 | + |
69 | 72 | speech_inputs = {} |
70 | 73 | batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios) |
71 | 74 | speech_inputs["input_features"] = self._extract_mel_spectrograms( |
@@ -95,15 +98,19 @@ def _ensure_melspec_transform_is_initialized(self): |
95 | 98 | We do this for now since some logging explodes since the mel spectrogram |
96 | 99 | transform is not JSON serializable. |
97 | 100 | """ |
| 101 | + requires_backends(self, ["torchaudio"]) |
| 102 | + |
98 | 103 | if self.melspec is None: |
99 | 104 | # TODO (@alex-jw-brooks / @eustlb) move this to common batch |
100 | 105 | # feature extraction in audio utils once they are written! |
101 | 106 | self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs) |
102 | 107 |
|
103 | | - def _extract_mel_spectrograms(self, audio: torch.Tensor, device="cpu"): |
| 108 | + def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"): |
104 | 109 | """ |
105 | 110 | Compute the Mel features to be passed to the conformer encoder. |
106 | 111 | """ |
| 112 | + requires_backends(self, ["torchaudio"]) |
| 113 | + |
107 | 114 | # Initialize the mel spectrogram if it isn't already and |
108 | 115 | # move the melspec / audio to the computation device. |
109 | 116 | self._ensure_melspec_transform_is_initialized() |
@@ -156,14 +163,16 @@ def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int] |
156 | 163 |
|
157 | 164 | return projector_lengths |
158 | 165 |
|
159 | | - def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence[torch.Tensor, Sequence[int]]: |
| 166 | + def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence["torch.Tensor", Sequence[int]]: |
160 | 167 | """ |
161 | 168 | Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking. |
162 | 169 |
|
163 | 170 | Args: |
164 | 171 | audios (`AudioInput`): |
165 | 172 | Audio sequence, numpy array, or torch tensor. |
166 | 173 | """ |
| 174 | + requires_backends(self, ["torch"]) |
| 175 | + |
167 | 176 | # Coerce to PyTorch tensors if we have numpy arrays, since |
168 | 177 | # currently we have a dependency on torch/torchaudio anyway |
169 | 178 | if isinstance(audios, np.ndarray): |
|
0 commit comments