Skip to content

Commit

Permalink
Added VectorQuantizer base class (#8011)
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <ajukic@nvidia.com>
  • Loading branch information
anteju authored Jan 20, 2024
1 parent bb575b7 commit dfaf500
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 47 deletions.
78 changes: 47 additions & 31 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -340,7 +341,50 @@ def forward(self, audio_real, audio_gen):
return scores_real, scores_gen, fmaps_real, fmaps_gen


class FiniteScalarQuantizer(NeuralModule):
class VectorQuantizerBase(NeuralModule, ABC):
    """Abstract interface shared by vector quantizer modules.

    Concrete quantizers implement `forward`, `encode` and `decode`. The
    neural types declared here describe the default contract: an encoded
    input with axes ('B', 'D', 'T') plus a length tensor with axis ('B',),
    a dequantized output with axes ('B', 'D', 'T') and indices with axes
    ('D', 'B', 'T').
    """

    @property
    def input_types(self):
        """Neural types of the arguments accepted by `forward`."""
        encoded_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
        length_type = NeuralType(('B',), LengthsType())
        return {"inputs": encoded_type, "input_len": length_type}

    @property
    def output_types(self):
        """Neural types of the values returned by `forward`."""
        dequantized_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
        indices_type = NeuralType(('D', 'B', 'T'), Index())
        return {"dequantized": dequantized_type, "indices": indices_type}

    @typecheck()
    @abstractmethod
    def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize `inputs`; must be implemented by subclasses.

        Args:
            inputs: encoded representation, typed with axes ('B', 'D', 'T').
            input_len: lengths tensor, typed with axis ('B',).

        Returns:
            A pair of tensors; under the default `output_types` these are
            the dequantized representation and the indices.
        """

    @typecheck(
        input_types=dict(
            inputs=NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            input_len=NeuralType(('B',), LengthsType()),
        ),
        output_types=dict(indices=NeuralType(('D', 'B', 'T'), Index())),
    )
    @abstractmethod
    def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Map `inputs` to indices; must be implemented by subclasses.

        Args:
            inputs: encoded representation, typed with axes ('B', 'D', 'T').
            input_len: lengths tensor, typed with axis ('B',).

        Returns:
            Indices, typed with axes ('D', 'B', 'T').
        """

    @typecheck(
        input_types=dict(
            indices=NeuralType(('D', 'B', 'T'), Index()),
            input_len=NeuralType(('B',), LengthsType()),
        ),
        output_types=dict(dequantized=NeuralType(('B', 'D', 'T'), EncodedRepresentation())),
    )
    @abstractmethod
    def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Map `indices` back to a dequantized representation; must be implemented by subclasses.

        Args:
            indices: indices, typed with axes ('D', 'B', 'T').
            input_len: lengths tensor, typed with axis ('B',).

        Returns:
            Dequantized representation, typed with axes ('B', 'D', 'T').
        """


class FiniteScalarQuantizer(VectorQuantizerBase):
"""This quantizer is based on the Finite Scalar Quantization (FSQ) method.
It quantizes each element of the input vector independently into a number of levels.
Expand Down Expand Up @@ -478,21 +522,7 @@ def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
indices = torch.sum(indices * self.dim_base_index, dim=1)
return indices.to(torch.int32)

# API of the RVQ
@property
def input_types(self):
    """Neural types of the inputs: encoded `inputs` and an optional `input_len`."""
    encoded_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    # `input_len` is declared optional in this variant.
    length_type = NeuralType(('B',), LengthsType(), optional=True)
    return {"inputs": encoded_type, "input_len": length_type}

@property
def output_types(self):
    """Neural types of the outputs: `dequantized` and `indices`."""
    dequantized_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    indices_type = NeuralType(('D', 'B', 'T'), Index())
    return {"dequantized": dequantized_type, "indices": indices_type}

# Implementation of VectorQuantizerBase API
@typecheck()
def forward(
self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -556,7 +586,7 @@ def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None
return dequantized


class GroupFiniteScalarQuantizer(NeuralModule):
class GroupFiniteScalarQuantizer(VectorQuantizerBase):
"""Split the input vector into groups and apply FSQ on each group separately.
This class is for convenience. Since FSQ is applied on each group separately,
groups can be defined arbitrarily by splitting the input vector. However, this
Expand Down Expand Up @@ -604,20 +634,6 @@ def codebook_size(self):
"""Returns the size of the implicit codebook."""
return self.codebook_size_per_group ** self.num_groups

@property
def input_types(self):
    """Neural types of the inputs: encoded `inputs` and `input_len`."""
    encoded_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    length_type = NeuralType(('B',), LengthsType())
    return {"inputs": encoded_type, "input_len": length_type}

@property
def output_types(self):
    """Neural types of the outputs: `dequantized` and `indices`."""
    dequantized_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    indices_type = NeuralType(('D', 'B', 'T'), Index())
    return {"dequantized": dequantized_type, "indices": indices_type}

@typecheck()
def forward(self, inputs, input_len):
"""Quantize each group separately, then concatenate the results.
Expand Down
21 changes: 5 additions & 16 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Conv1dNorm,
Conv2dNorm,
ConvTranspose1dNorm,
VectorQuantizerBase,
get_down_sample_padding,
)
from nemo.collections.tts.parts.utils.distributed import broadcast_tensors
Expand Down Expand Up @@ -690,7 +691,7 @@ def decode(self, indices, input_len):
return dequantized


class ResidualVectorQuantizer(NeuralModule):
class ResidualVectorQuantizer(VectorQuantizerBase):
"""
Residual vector quantization (RVQ) algorithm as described in https://arxiv.org/pdf/2107.03312.pdf.
Expand Down Expand Up @@ -732,13 +733,7 @@ def __init__(
]
)

@property
def input_types(self):
    """Neural types of the inputs: encoded `inputs` and `input_len`."""
    encoded_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    length_type = NeuralType(('B',), LengthsType())
    return {"inputs": encoded_type, "input_len": length_type}

# Override output types, since this quantizer returns commit_loss
@property
def output_types(self):
return {
Expand Down Expand Up @@ -818,7 +813,7 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
return dequantized


class GroupResidualVectorQuantizer(NeuralModule):
class GroupResidualVectorQuantizer(VectorQuantizerBase):
"""Split the input vector into groups and apply RVQ on each group separately.
Args:
Expand Down Expand Up @@ -875,13 +870,7 @@ def codebook_dim_per_group(self):

return self.codebook_dim // self.num_groups

@property
def input_types(self):
    """Neural types of the inputs: encoded `inputs` and `input_len`."""
    encoded_type = NeuralType(('B', 'D', 'T'), EncodedRepresentation())
    length_type = NeuralType(('B',), LengthsType())
    return {"inputs": encoded_type, "input_len": length_type}

# Override output types, since this quantizer returns commit_loss
@property
def output_types(self):
return {
Expand Down

0 comments on commit dfaf500

Please sign in to comment.