-
Notifications
You must be signed in to change notification settings - Fork 9
/
vqwv2vec_feat_extract.py
54 lines (47 loc) · 1.6 KB
/
vqwv2vec_feat_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# vq-wav2vec feature extraction script.
# Usage: python vqwv2vec_feat_extract.py <audio_dir> <out_dir> <checkpoint>
#   audio_dir : root containing <speaker>/<utt>.wav files
#   out_dir   : output root for extracted .npy features
#   ckpt      : path to a fairseq vq-wav2vec checkpoint
import torch
import fairseq
import torchaudio
import sys
import numpy as np
import librosa
from scipy.io import wavfile
# Positional CLI arguments (no validation — script exits with IndexError if missing).
audio_dir=sys.argv[1]
out_dir=sys.argv[2]
ckpt=sys.argv[3]
import glob
import os
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import subprocess
from tqdm import tqdm
# Expects a two-level layout: <audio_dir>/<speaker>/<utterance>.wav
audio_paths = list(glob.glob(os.path.join(audio_dir,'*/*.wav')))
def load_wav(path):
    """Read a wav file, normalize to float in [-1, 1], and resample to 16 kHz.

    Parameters
    ----------
    path : str
        Path to the input wav file.

    Returns
    -------
    np.ndarray
        Waveform at 16 kHz, clipped to [-1.0, 1.0]. int16 PCM input is
        converted to float32; other dtypes pass through unscaled.
    """
    sr, x = wavfile.read(path)
    signed_int16_max = 2**15
    if x.dtype == np.int16:
        # Scale raw PCM int16 samples into [-1, 1).
        x = x.astype(np.float32) / signed_int16_max
        # (was '24khz wav' — misleading: this branch is about dtype, not rate)
        print(f'int16 wav normalized {x.shape}')
    if sr != 16000:
        # librosa >= 0.10 removed the positional (y, orig_sr, target_sr) form;
        # sample rates must be passed as keyword arguments.
        x = librosa.resample(x, orig_sr=sr, target_sr=16000)
        print(f'resample {x.shape}')
    x = np.clip(x, -1.0, 1.0)
    return x
def process(cp, _audio_paths, out_dir):
    """Extract vq-wav2vec dense features for each wav and save them as .npy.

    For each input path the dense (pre-quantization) features are written to
    <out_dir>/<speaker>/<utt>_dense.npy, where <speaker> is the wav's parent
    directory name.

    Parameters
    ----------
    cp : str
        Path to a fairseq vq-wav2vec checkpoint.
    _audio_paths : list[str]
        Wav files laid out as <audio_dir>/<speaker>/<utt>.wav.
    out_dir : str
        Output root directory (created per-speaker as needed).
    """
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
    model = model[0]
    model.eval()
    for audio_path in _audio_paths:
        wav = load_wav(audio_path)
        # Add a batch dimension: (T,) -> (1, T).
        wav_input_16khz = torch.FloatTensor(wav).unsqueeze(0)
        # Inference only: no_grad avoids building autograd graphs per
        # utterance, keeping memory flat across the loop.
        with torch.no_grad():
            z = model.feature_extractor(wav_input_16khz)
            print(f"z {z.size()}")
            dense, idxs = model.vector_quantizer.forward_idx(z)
        dense = dense[0].data.numpy()
        # NOTE(review): idxs is computed/printed but never saved — presumably
        # only the dense features are needed downstream; confirm.
        idxs = idxs[0].data.numpy()
        print(f" dense {dense.shape} idxs {idxs.shape}")
        # splitext keeps dots inside the utterance name, unlike split('.')[0]
        # which would truncate e.g. "utt.v2.wav" to "utt".
        file_id = os.path.splitext(os.path.basename(audio_path))[0]
        spk = os.path.basename(os.path.dirname(audio_path))
        os.makedirs(os.path.join(out_dir, spk), exist_ok=True)
        np.save(os.path.join(out_dir, spk, file_id + '_dense'), dense)
process(ckpt, audio_paths, out_dir )