Skip to content

Commit

Permalink
Merge pull request #1 from m-bain/main
Browse files (browse the repository at this point in the history)
Fix VAD Path for Custom VAD loading
  • Loading branch information
Swami-Abhinav authored Jan 4, 2024
2 parents 6bb2f1c + 8227807 commit a93ca91
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
]
- + [f"pyannote.audio==3.1.0"],
+ + [f"pyannote.audio==3.1.1"],
entry_points={
"console_scripts": ["whisperx=whisperx.transcribe:cli"],
},
Expand Down
4 changes: 2 additions & 2 deletions whisperx/diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def __init__(
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

- def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
+ def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
- segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
+ segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
Expand Down

0 comments on commit a93ca91

Please sign in to comment.