Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added VectorQuantizer base class #8011

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 47 additions & 31 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -340,7 +341,50 @@ def forward(self, audio_real, audio_gen):
return scores_real, scores_gen, fmaps_real, fmaps_gen


class FiniteScalarQuantizer(NeuralModule):
class VectorQuantizerBase(NeuralModule, ABC):
    """Abstract interface for vector quantizers used by audio codec models.

    Concrete subclasses (e.g. FSQ- or RVQ-style quantizers) must implement
    ``forward``, ``encode`` and ``decode`` with the neural-type contracts
    declared below: continuous encoder features in, discrete codebook
    indices and/or dequantized features out.
    """

    @property
    def input_types(self):
        # Default input contract for forward():
        #   inputs:    (batch, feature dim, time) continuous encoded representation
        #   input_len: (batch,) valid length of each sequence in frames
        return {
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        # Default output contract for forward():
        #   dequantized: (batch, feature dim, time) reconstruction of the inputs
        #   indices:     (D, batch, time) integer codebook indices
        #                NOTE(review): the leading 'D' axis presumably indexes
        #                codebooks/groups — confirm against subclass implementations.
        return {
            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "indices": NeuralType(('D', 'B', 'T'), Index()),
        }

    @typecheck()
    @abstractmethod
    def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``inputs`` and return ``(dequantized, indices)``.

        Types are enforced by ``@typecheck()`` using the ``input_types`` /
        ``output_types`` properties above.
        """
        pass

    # encode(): continuous features -> discrete indices only.
    # The typecheck contract is declared inline because it differs from the
    # class-level output_types (no dequantized output).
    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
    )
    @abstractmethod
    def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Convert continuous input features to discrete codebook indices."""
        pass

    # decode(): discrete indices -> continuous (dequantized) features.
    # Inverse of encode(); inline typecheck contract mirrors it.
    @typecheck(
        input_types={
            "indices": NeuralType(('D', 'B', 'T'), Index()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),},
    )
    @abstractmethod
    def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Convert discrete codebook indices back to continuous features."""
        pass


class FiniteScalarQuantizer(VectorQuantizerBase):
"""This quantizer is based on the Finite Scalar Quantization (FSQ) method.
It quantizes each element of the input vector independently into a number of levels.

Expand Down Expand Up @@ -478,21 +522,7 @@ def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
indices = torch.sum(indices * self.dim_base_index, dim=1)
return indices.to(torch.int32)

# API of the RVQ
@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType(), optional=True),
}

@property
def output_types(self):
return {
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"indices": NeuralType(('D', 'B', 'T'), Index()),
}

    # Implementation of VectorQuantizerBase API
@typecheck()
def forward(
self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -556,7 +586,7 @@ def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None
return dequantized


class GroupFiniteScalarQuantizer(NeuralModule):
class GroupFiniteScalarQuantizer(VectorQuantizerBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the quantizers have the same interface then would it make sense to have just 1 grouped quantizer implementation/interface? I guess there would need to be a small amount of logic to specify how the underlying codebooks are created.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I had the same idea initially.
However, for now, I decided to leave them as separate implementations and only require them to conform with VectorQuantizerBase. The main reason is flexibility.

Unless there's a strong reason to do it right away, I'd suggest integrating this PR in the current form and revisit the group class in the future.

"""Split the input vector into groups and apply FSQ on each group separately.
This class is for convenience. Since FSQ is applied on each group separately,
groups can be defined arbitrarily by splitting the input vector. However, this
Expand Down Expand Up @@ -604,20 +634,6 @@ def codebook_size(self):
"""Returns the size of the implicit codebook."""
return self.codebook_size_per_group ** self.num_groups

@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"indices": NeuralType(('D', 'B', 'T'), Index()),
}

@typecheck()
def forward(self, inputs, input_len):
"""Quantize each group separately, then concatenate the results.
Expand Down
21 changes: 5 additions & 16 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Conv1dNorm,
Conv2dNorm,
ConvTranspose1dNorm,
VectorQuantizerBase,
get_down_sample_padding,
)
from nemo.collections.tts.parts.utils.distributed import broadcast_tensors
Expand Down Expand Up @@ -690,7 +691,7 @@ def decode(self, indices, input_len):
return dequantized


class ResidualVectorQuantizer(NeuralModule):
class ResidualVectorQuantizer(VectorQuantizerBase):
"""
Residual vector quantization (RVQ) algorithm as described in https://arxiv.org/pdf/2107.03312.pdf.

Expand Down Expand Up @@ -732,13 +733,7 @@ def __init__(
]
)

@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

# Override output types, since this quantizer returns commit_loss
@property
def output_types(self):
return {
Expand Down Expand Up @@ -818,7 +813,7 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
return dequantized


class GroupResidualVectorQuantizer(NeuralModule):
class GroupResidualVectorQuantizer(VectorQuantizerBase):
"""Split the input vector into groups and apply RVQ on each group separately.

Args:
Expand Down Expand Up @@ -875,13 +870,7 @@ def codebook_dim_per_group(self):

return self.codebook_dim // self.num_groups

@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

# Override output types, since this quantizer returns commit_loss
@property
def output_types(self):
return {
Expand Down
Loading