Skip to content

Commit

Permalink
✏️ Add attention_mask in SpeechRecognitionDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed May 6, 2024
1 parent 4b5f446 commit b8a6479
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions hezar/data/datasets/speech_recognition_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ def __getitem__(self, index):
labels = self.tokenizer(
transcript,
max_length=self.config.labels_max_length,
return_tensors="pt"
)["token_ids"]
return_tensors="pt",
)

return {
"input_features": input_features,
"labels": labels
"labels": labels["token_ids"],
"attention_mask": labels["attention_mask"],
}

0 comments on commit b8a6479

Please sign in to comment.