Skip to content

Commit

Permalink
Return attention mask in ASR pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Sep 16, 2024
1 parent 8f8af0f commit 8d0d13b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
truncation=False,
padding="longest",
return_tensors="pt",
return_attention_mask=True,
)
else:
if self.type == "seq2seq_whisper" and stride is None:
Expand All @@ -448,13 +449,16 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
sampling_rate=self.feature_extractor.sampling_rate,
return_tensors="pt",
return_token_timestamps=True,
return_attention_mask=True,
)
extra["num_frames"] = processed.pop("num_frames")
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
return_tensors="pt",
return_attention_mask=True,
)

if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
if stride is not None:
Expand Down

0 comments on commit 8d0d13b

Please sign in to comment.