From b14084412d312485f7ae282d3cef8e4133f7db56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Jukic=CC=81?= Date: Mon, 18 Sep 2023 21:31:49 -0700 Subject: [PATCH] Group-residual vector quantizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- .../tts/modules/encodec_modules.py | 141 ++++++++++++++++++ .../tts/modules/test_audio_codec_modules.py | 93 ++++++++++++ 2 files changed, 234 insertions(+) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 8c424351ce35..c26bc5d9a31e 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -56,6 +56,7 @@ from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType from nemo.core.neural_types.neural_type import NeuralType +from nemo.utils import logging from nemo.utils.decorators import experimental @@ -807,3 +808,143 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: dequantized = dequantized + dequantized_i dequantized = rearrange(dequantized, "B T D -> B D T") return dequantized + + +class GroupResidualVectorQuantizer(NeuralModule): + """Split the input vector into groups and apply RVQ on each group separately. + + Args: + num_codebooks: total number of codebooks + num_groups: number of groups to split the input into, each group will be quantized separately using num_codebooks//num_groups codebooks + codebook_dim: embedding dimension, will be split into num_groups + **kwargs: parameters of ResidualVectorQuantizer + + References: + Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765). + """ + + def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs): + super().__init__() + + self.num_codebooks = num_codebooks + self.num_groups = num_groups + self.codebook_dim = codebook_dim + + # Initialize RVQ for each group + self.rvqs = torch.nn.ModuleList( + [ + ResidualVectorQuantizer( + num_codebooks=self.num_codebooks_per_group, codebook_dim=self.codebook_dim_per_group, **kwargs + ) + for _ in range(self.num_groups) + ] + ) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tnum_codebooks: %d', self.num_codebooks) + logging.debug('\tnum_groups: %d', self.num_groups) + logging.debug('\tcodebook_dim: %d', self.codebook_dim) + logging.debug('\tnum_codebooks_per_group: %d', self.num_codebooks_per_group) + logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group) + + @property + def num_codebooks_per_group(self): + """Number of codebooks for each group. + """ + if self.num_codebooks % self.num_groups != 0: + raise ValueError( + f'num_codebooks ({self.num_codebooks}) must be divisible by num_groups ({self.num_groups})' + ) + + return self.num_codebooks // self.num_groups + + @property + def codebook_dim_per_group(self): + """Input vector dimension for each group. + """ + if self.codebook_dim % self.num_groups != 0: + raise ValueError(f'codebook_dim ({self.codebook_dim}) must be divisible by num_groups ({self.num_groups})') + + return self.codebook_dim // self.num_groups + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('D', 'B', 'T'), Index()), + "commit_loss": NeuralType((), LossType()), + } + + @typecheck() + def forward(self, inputs, input_len): + """Quantize each group separately, then concatenate the results. + """ + inputs_grouped = inputs.chunk(self.num_groups, dim=1) + + dequantized, indices = [], [] + commit_loss = 0 + + for in_group, rvq_group in zip(inputs_grouped, self.rvqs): + dequantized_group, indices_group, commit_loss_group = rvq_group(inputs=in_group, input_len=input_len) + dequantized.append(dequantized_group) + indices.append(indices_group) + commit_loss += commit_loss_group + + # concatenate along the feature dimension + dequantized = torch.cat(dequantized, dim=1) + + # concatente along the codebook dimension + indices = torch.cat(indices, dim=0) + + return dequantized, indices, commit_loss + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + ) + def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: + """Input is split into groups, each group is encoded separately, then the results are concatenated. + """ + inputs_grouped = inputs.chunk(self.num_groups, dim=1) + indices = [] + + for in_group, rvq_group in zip(inputs_grouped, self.rvqs): + indices_group = rvq_group.encode(inputs=in_group, input_len=input_len) + indices.append(indices_group) + + # concatenate along the codebook dimension + indices = torch.cat(indices, dim=0) + + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('D', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: + """Input indices are split into groups, each group is decoded separately, then the results are concatenated. + """ + indices_grouped = indices.chunk(self.num_groups, dim=0) + dequantized = [] + + for indices_group, rvq_group in zip(indices_grouped, self.rvqs): + dequantized_group = rvq_group.decode(indices=indices_group, input_len=input_len) + dequantized.append(dequantized_group) + + # concatenate along the feature dimension + dequantized = torch.cat(dequantized, dim=1) + + return dequantized diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index b48b415547fe..3b02552f3d73 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -16,6 +16,7 @@ import torch from nemo.collections.tts.modules.audio_codec_modules import Conv1dNorm, ConvTranspose1dNorm, get_down_sample_padding +from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer class TestAudioCodecModules: @@ -89,3 +90,95 @@ def test_conv1d_transpose_upsample(self): assert torch.all(out[0, :, out_len_1:] == 0.0) assert torch.all(out[1, :, :out_len_2] != 0.0) assert torch.all(out[1, :, out_len_2:] == 0.0) + + +class TestResidualVectorQuantizer: + def setup_class(self): + """Setup common members + """ + self.batch_size = 2 + self.max_len = 20 + self.codebook_size = 256 + self.codebook_dim = 64 + self.num_examples = 10 + + @pytest.mark.unit + @pytest.mark.parametrize('num_codebooks', [1, 4]) + def test_rvq_eval(self, num_codebooks: int): + """Simple test to confirm that the RVQ module can be instantiated and run, + and that forward produces the same result as encode-decode. + """ + # instantiate and set in eval mode + rvq = ResidualVectorQuantizer(num_codebooks=num_codebooks, codebook_dim=self.codebook_dim) + rvq.eval() + + for i in range(self.num_examples): + inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len]) + input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32) + + # apply forward + dequantized_fw, indices_fw, commit_loss = rvq(inputs=inputs, input_len=input_len) + + # make sure the commit loss is zero + assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0' + + # encode-decode + indices_enc = rvq.encode(inputs=inputs, input_len=input_len) + dequantized_dec = rvq.decode(indices=indices_enc, input_len=input_len) + + # make sure the results are the same + torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') + torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') + + @pytest.mark.unit + @pytest.mark.parametrize('num_groups', [1, 2, 4]) + @pytest.mark.parametrize('num_codebooks', [1, 4]) + def test_group_rvq_eval(self, num_groups: int, num_codebooks: int): + """Simple test to confirm that the group RVQ module can be instantiated and run, + and that forward produces the same result as encode-decode. + """ + if num_groups > num_codebooks: + # Expected to fail if num_groups is lager than the total number of codebooks + with pytest.raises(ValueError): + _ = GroupResidualVectorQuantizer( + num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim + ) + else: + # Test inference with group RVQ + # instantiate and set in eval mode + grvq = GroupResidualVectorQuantizer( + num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim + ) + grvq.eval() + + for i in range(self.num_examples): + inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len]) + input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32) + + # apply forward + dequantized_fw, indices_fw, commit_loss = grvq(inputs=inputs, input_len=input_len) + + # make sure the commit loss is zero + assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0' + + # encode-decode + indices_enc = grvq.encode(inputs=inputs, input_len=input_len) + dequantized_dec = grvq.decode(indices=indices_enc, input_len=input_len) + + # make sure the results are the same + torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') + torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') + + # apply individual RVQs and make sure the results are the same + inputs_grouped = inputs.chunk(num_groups, dim=1) + dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1) + indices_fw_grouped = indices_fw.chunk(num_groups, dim=0) + + for g in range(num_groups): + dequantized, indices, _ = grvq.rvqs[g](inputs=inputs_grouped[g], input_len=input_len) + torch.testing.assert_close( + dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}' + ) + torch.testing.assert_close( + indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}' + )