Skip to content

Commit fd2ea95

Browse files
authoredJun 27, 2023
hubert
1 parent 1a7f015 commit fd2ea95

File tree

4 files changed

+273
-0
lines changed

4 files changed

+273
-0
lines changed
 

‎hubert/__init__.py

Whitespace-only changes.

‎hubert/hubert_model.py

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import copy
2+
import random
3+
from typing import Optional, Tuple
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as t_func
8+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
9+
10+
11+
class Hubert(nn.Module):
12+
def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
13+
super().__init__()
14+
self._mask = mask
15+
self.feature_extractor = FeatureExtractor()
16+
self.feature_projection = FeatureProjection()
17+
self.positional_embedding = PositionalConvEmbedding()
18+
self.norm = nn.LayerNorm(768)
19+
self.dropout = nn.Dropout(0.1)
20+
self.encoder = TransformerEncoder(
21+
nn.TransformerEncoderLayer(
22+
768, 12, 3072, activation="gelu", batch_first=True
23+
),
24+
12,
25+
)
26+
self.proj = nn.Linear(768, 256)
27+
28+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
29+
self.label_embedding = nn.Embedding(num_label_embeddings, 256)
30+
31+
def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
32+
mask = None
33+
if self.training and self._mask:
34+
mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
35+
x[mask] = self.masked_spec_embed.to(x.dtype)
36+
return x, mask
37+
38+
def encode(
39+
self, x: torch.Tensor, layer: Optional[int] = None
40+
) -> Tuple[torch.Tensor, torch.Tensor]:
41+
x = self.feature_extractor(x)
42+
x = self.feature_projection(x.transpose(1, 2))
43+
x, mask = self.mask(x)
44+
x = x + self.positional_embedding(x)
45+
x = self.dropout(self.norm(x))
46+
x = self.encoder(x, output_layer=layer)
47+
return x, mask
48+
49+
def logits(self, x: torch.Tensor) -> torch.Tensor:
50+
logits = torch.cosine_similarity(
51+
x.unsqueeze(2),
52+
self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
53+
dim=-1,
54+
)
55+
return logits / 0.1
56+
57+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
58+
x, mask = self.encode(x)
59+
x = self.proj(x)
60+
logits = self.logits(x)
61+
return logits, mask
62+
63+
64+
class HubertSoft(Hubert):
65+
def __init__(self):
66+
super().__init__()
67+
68+
@torch.inference_mode()
69+
def units(self, wav: torch.Tensor) -> torch.Tensor:
70+
wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
71+
x, _ = self.encode(wav)
72+
return self.proj(x)
73+
74+
75+
class FeatureExtractor(nn.Module):
76+
def __init__(self):
77+
super().__init__()
78+
self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
79+
self.norm0 = nn.GroupNorm(512, 512)
80+
self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
81+
self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
82+
self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
83+
self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
84+
self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
85+
self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
86+
87+
def forward(self, x: torch.Tensor) -> torch.Tensor:
88+
x = t_func.gelu(self.norm0(self.conv0(x)))
89+
x = t_func.gelu(self.conv1(x))
90+
x = t_func.gelu(self.conv2(x))
91+
x = t_func.gelu(self.conv3(x))
92+
x = t_func.gelu(self.conv4(x))
93+
x = t_func.gelu(self.conv5(x))
94+
x = t_func.gelu(self.conv6(x))
95+
return x
96+
97+
98+
class FeatureProjection(nn.Module):
99+
def __init__(self):
100+
super().__init__()
101+
self.norm = nn.LayerNorm(512)
102+
self.projection = nn.Linear(512, 768)
103+
self.dropout = nn.Dropout(0.1)
104+
105+
def forward(self, x: torch.Tensor) -> torch.Tensor:
106+
x = self.norm(x)
107+
x = self.projection(x)
108+
x = self.dropout(x)
109+
return x
110+
111+
112+
class PositionalConvEmbedding(nn.Module):
113+
def __init__(self):
114+
super().__init__()
115+
self.conv = nn.Conv1d(
116+
768,
117+
768,
118+
kernel_size=128,
119+
padding=128 // 2,
120+
groups=16,
121+
)
122+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
123+
124+
def forward(self, x: torch.Tensor) -> torch.Tensor:
125+
x = self.conv(x.transpose(1, 2))
126+
x = t_func.gelu(x[:, :, :-1])
127+
return x.transpose(1, 2)
128+
129+
130+
class TransformerEncoder(nn.Module):
131+
def __init__(
132+
self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
133+
) -> None:
134+
super(TransformerEncoder, self).__init__()
135+
self.layers = nn.ModuleList(
136+
[copy.deepcopy(encoder_layer) for _ in range(num_layers)]
137+
)
138+
self.num_layers = num_layers
139+
140+
def forward(
141+
self,
142+
src: torch.Tensor,
143+
mask: torch.Tensor = None,
144+
src_key_padding_mask: torch.Tensor = None,
145+
output_layer: Optional[int] = None,
146+
) -> torch.Tensor:
147+
output = src
148+
for layer in self.layers[:output_layer]:
149+
output = layer(
150+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
151+
)
152+
return output
153+
154+
155+
def _compute_mask(
156+
shape: Tuple[int, int],
157+
mask_prob: float,
158+
mask_length: int,
159+
device: torch.device,
160+
min_masks: int = 0,
161+
) -> torch.Tensor:
162+
batch_size, sequence_length = shape
163+
164+
if mask_length < 1:
165+
raise ValueError("`mask_length` has to be bigger than 0.")
166+
167+
if mask_length > sequence_length:
168+
raise ValueError(
169+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
170+
)
171+
172+
# compute number of masked spans in batch
173+
num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
174+
num_masked_spans = max(num_masked_spans, min_masks)
175+
176+
# make sure num masked indices <= sequence_length
177+
if num_masked_spans * mask_length > sequence_length:
178+
num_masked_spans = sequence_length // mask_length
179+
180+
# SpecAugment mask to fill
181+
mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
182+
183+
# uniform distribution to sample from, make sure that offset samples are < sequence_length
184+
uniform_dist = torch.ones(
185+
(batch_size, sequence_length - (mask_length - 1)), device=device
186+
)
187+
188+
# get random indices to mask
189+
mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
190+
191+
# expand masked indices to masked spans
192+
mask_indices = (
193+
mask_indices.unsqueeze(dim=-1)
194+
.expand((batch_size, num_masked_spans, mask_length))
195+
.reshape(batch_size, num_masked_spans * mask_length)
196+
)
197+
offsets = (
198+
torch.arange(mask_length, device=device)[None, None, :]
199+
.expand((batch_size, num_masked_spans, mask_length))
200+
.reshape(batch_size, num_masked_spans * mask_length)
201+
)
202+
mask_idxs = mask_indices + offsets
203+
204+
# scatter indices to mask
205+
mask = mask.scatter(1, mask_idxs, True)
206+
207+
return mask
208+
209+
210+
def hubert_soft(
211+
path: str,
212+
) -> HubertSoft:
213+
r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
214+
Args:
215+
path (str): path of a pretrained model
216+
"""
217+
hubert = HubertSoft()
218+
checkpoint = torch.load(path)
219+
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
220+
hubert.load_state_dict(checkpoint)
221+
hubert.eval()
222+
return hubert

‎hubert/inference.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import sys
3+
current_dir = os.path.dirname(os.path.abspath(__file__))
4+
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
5+
sys.path.append(parent_dir)
6+
import numpy as np
7+
import argparse
8+
import torch
9+
10+
from whisper.audio import load_audio
11+
from hubert import hubert_model
12+
13+
14+
def load_model(path, device):
15+
model = hubert_model.hubert_soft(path)
16+
model.eval()
17+
model.half()
18+
model.to(device)
19+
return model
20+
21+
22+
def pred_vec(model, wavPath, vecPath, device):
23+
feats = load_audio(wavPath)
24+
feats = torch.from_numpy(feats).to(device)
25+
feats = feats[None, None, :].half()
26+
with torch.no_grad():
27+
vec = model.units(feats).squeeze().data.cpu().float().numpy()
28+
# print(vec.shape) # [length, dim=256] hop=320
29+
np.save(vecPath, vec, allow_pickle=False)
30+
31+
32+
if __name__ == "__main__":
33+
parser = argparse.ArgumentParser()
34+
parser.description = 'please enter embed parameter ...'
35+
parser.add_argument("-w", "--wav", help="wav", dest="wav")
36+
parser.add_argument("-v", "--vec", help="vec", dest="vec")
37+
args = parser.parse_args()
38+
print(args.wav)
39+
print(args.vec)
40+
41+
wavPath = args.wav
42+
vecPath = args.vec
43+
44+
assert torch.cuda.is_available()
45+
device = "cuda"
46+
hubert = load_model(os.path.join(
47+
"hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device)
48+
pred_vec(hubert, wavPath, vecPath, device)

‎hubert_pretrain/README.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Path for:
2+
3+
hubert-soft-0d54a1f4.pt

0 commit comments

Comments
 (0)
Please sign in to comment.