Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group-residual vector quantizer #7643

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging
from nemo.utils.decorators import experimental


Expand Down Expand Up @@ -807,3 +808,143 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
dequantized = dequantized + dequantized_i
dequantized = rearrange(dequantized, "B T D -> B D T")
return dequantized


class GroupResidualVectorQuantizer(NeuralModule):
"""Split the input vector into groups and apply RVQ on each group separately.

Args:
num_codebooks: total number of codebooks
num_groups: number of groups to split the input into, each group will be quantized separately using num_codebooks//num_groups codebooks
codebook_dim: embedding dimension, will be split into num_groups
**kwargs: parameters of ResidualVectorQuantizer

References:
Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765).
"""

def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs):
super().__init__()

self.num_codebooks = num_codebooks
self.num_groups = num_groups
self.codebook_dim = codebook_dim

# Initialize RVQ for each group
self.rvqs = torch.nn.ModuleList(
[
ResidualVectorQuantizer(
num_codebooks=self.num_codebooks_per_group, codebook_dim=self.codebook_dim_per_group, **kwargs
)
for _ in range(self.num_groups)
]
)

logging.debug('Initialized %s with', self.__class__.__name__)
logging.debug('\tnum_codebooks: %d', self.num_codebooks)
logging.debug('\tnum_groups: %d', self.num_groups)
logging.debug('\tcodebook_dim: %d', self.codebook_dim)
logging.debug('\tnum_codebooks_per_group: %d', self.num_codebooks_per_group)
logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group)

@property
def num_codebooks_per_group(self):
"""Number of codebooks for each group.
"""
if self.num_codebooks % self.num_groups != 0:
raise ValueError(
f'num_codebooks ({self.num_codebooks}) must be divisible by num_groups ({self.num_groups})'
)

return self.num_codebooks // self.num_groups

@property
def codebook_dim_per_group(self):
"""Input vector dimension for each group.
"""
if self.codebook_dim % self.num_groups != 0:
raise ValueError(f'codebook_dim ({self.codebook_dim}) must be divisible by num_groups ({self.num_groups})')

return self.codebook_dim // self.num_groups

@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"indices": NeuralType(('D', 'B', 'T'), Index()),
"commit_loss": NeuralType((), LossType()),
}

@typecheck()
def forward(self, inputs, input_len):
"""Quantize each group separately, then concatenate the results.
"""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)

dequantized, indices = [], []
commit_loss = 0

for in_group, rvq_group in zip(inputs_grouped, self.rvqs):
dequantized_group, indices_group, commit_loss_group = rvq_group(inputs=in_group, input_len=input_len)
dequantized.append(dequantized_group)
indices.append(indices_group)
commit_loss += commit_loss_group

# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)

# concatente along the codebook dimension
indices = torch.cat(indices, dim=0)

return dequantized, indices, commit_loss

@typecheck(
input_types={
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
)
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor:
"""Input is split into groups, each group is encoded separately, then the results are concatenated.
"""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)
indices = []

for in_group, rvq_group in zip(inputs_grouped, self.rvqs):
indices_group = rvq_group.encode(inputs=in_group, input_len=input_len)
indices.append(indices_group)

# concatenate along the codebook dimension
indices = torch.cat(indices, dim=0)

return indices

@typecheck(
input_types={
"indices": NeuralType(('D', 'B', 'T'), Index()),
"input_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),},
)
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
"""Input indices are split into groups, each group is decoded separately, then the results are concatenated.
"""
indices_grouped = indices.chunk(self.num_groups, dim=0)
dequantized = []

for indices_group, rvq_group in zip(indices_grouped, self.rvqs):
dequantized_group = rvq_group.decode(indices=indices_group, input_len=input_len)
dequantized.append(dequantized_group)

# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)

return dequantized
93 changes: 93 additions & 0 deletions tests/collections/tts/modules/test_audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch

from nemo.collections.tts.modules.audio_codec_modules import Conv1dNorm, ConvTranspose1dNorm, get_down_sample_padding
from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer


class TestAudioCodecModules:
Expand Down Expand Up @@ -89,3 +90,95 @@ def test_conv1d_transpose_upsample(self):
assert torch.all(out[0, :, out_len_1:] == 0.0)
assert torch.all(out[1, :, :out_len_2] != 0.0)
assert torch.all(out[1, :, out_len_2:] == 0.0)


class TestResidualVectorQuantizer:
def setup_class(self):
"""Setup common members
"""
self.batch_size = 2
self.max_len = 20
self.codebook_size = 256
self.codebook_dim = 64
self.num_examples = 10

@pytest.mark.unit
@pytest.mark.parametrize('num_codebooks', [1, 4])
def test_rvq_eval(self, num_codebooks: int):
"""Simple test to confirm that the RVQ module can be instantiated and run,
and that forward produces the same result as encode-decode.
"""
# instantiate and set in eval mode
rvq = ResidualVectorQuantizer(num_codebooks=num_codebooks, codebook_dim=self.codebook_dim)
rvq.eval()

for i in range(self.num_examples):
inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

# apply forward
dequantized_fw, indices_fw, commit_loss = rvq(inputs=inputs, input_len=input_len)

# make sure the commit loss is zero
assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

# encode-decode
indices_enc = rvq.encode(inputs=inputs, input_len=input_len)
dequantized_dec = rvq.decode(indices=indices_enc, input_len=input_len)

# make sure the results are the same
torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

@pytest.mark.unit
@pytest.mark.parametrize('num_groups', [1, 2, 4])
@pytest.mark.parametrize('num_codebooks', [1, 4])
def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
"""Simple test to confirm that the group RVQ module can be instantiated and run,
and that forward produces the same result as encode-decode.
"""
if num_groups > num_codebooks:
# Expected to fail if num_groups is lager than the total number of codebooks
with pytest.raises(ValueError):
_ = GroupResidualVectorQuantizer(
num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
)
else:
# Test inference with group RVQ
# instantiate and set in eval mode
grvq = GroupResidualVectorQuantizer(
num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
)
grvq.eval()

for i in range(self.num_examples):
inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

# apply forward
dequantized_fw, indices_fw, commit_loss = grvq(inputs=inputs, input_len=input_len)

# make sure the commit loss is zero
assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

# encode-decode
indices_enc = grvq.encode(inputs=inputs, input_len=input_len)
dequantized_dec = grvq.decode(indices=indices_enc, input_len=input_len)

# make sure the results are the same
torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

# apply individual RVQs and make sure the results are the same
inputs_grouped = inputs.chunk(num_groups, dim=1)
dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1)
indices_fw_grouped = indices_fw.chunk(num_groups, dim=0)

for g in range(num_groups):
dequantized, indices, _ = grvq.rvqs[g](inputs=inputs_grouped[g], input_len=input_len)
torch.testing.assert_close(
dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}'
)
torch.testing.assert_close(
indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
)
Loading