forked from facebookresearch/multimodal
Commit
Added quantisation layer, commitment loss, and associated unit tests (facebookresearch#29)

Summary: Vector Quantized Variational Autoencoder (VQVAE) is widely used to generate multimodal data (audio, image, or video samples). A well-known model built on this idea is OpenAI's DALL-E, which can generate images from text captions. The core component of the VQVAE is the quantisation layer: a discrete embedding space (codebook) used to discretise the outputs of an encoder. Here, this layer is implemented as its own module, along with the commitment loss that is used to train the encoder.

Pull Request resolved: facebookresearch#29

Test Plan: Unit tests were created and added in this PR. Testing was done with PyTest.

`pytest -vv test/modules/losses/test_commitment.py`

<img width="566" alt="image" src="https://user-images.githubusercontent.com/33648637/165840071-bac97cc4-1534-486c-a86d-1e65d50141cc.png">

`pytest -vv test/modules/layers/test_quantisation.py`

<img width="565" alt="image" src="https://user-images.githubusercontent.com/33648637/165840175-decd1098-0d5a-418c-8806-0bf535999229.png">

Reviewed By: langong347

Differential Revision: D36030926

Pulled By: RdoubleA

fbshipit-source-id: 6478df6d9cd3f07a48cb4a9ab492cde1679cf1b6
1 parent 28d222a · commit 292219e
Showing 4 changed files with 261 additions and 0 deletions.
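For orientation, the sketch below shows how the two new modules could fit together in a VQVAE training step. The `encoder` is a hypothetical stand-in (any module whose channel dimension matches `embedding_dim` would do) and the `commitment_cost` value is an arbitrary choice; only `Quantisation` and `CommitmentLoss` come from this commit.

```python
import torch
from torch import nn
from torchmultimodal.modules.layers.quantisation import Quantisation
from torchmultimodal.modules.losses.vqvae import CommitmentLoss

# Hypothetical encoder: maps a batch x 3 x 8 input to a batch x 5 x 8 output,
# so the channel dimension (5) matches embedding_dim.
encoder = nn.Conv1d(in_channels=3, out_channels=5, kernel_size=1)
quantisation = Quantisation(num_embeddings=4, embedding_dim=5)
commitment = CommitmentLoss(commitment_cost=0.25)

x = torch.randn(1, 3, 8)
encoded = encoder(x)                   # batch x channel x n dims
quantised = quantisation(encoded)      # nearest codebook vectors, same shape
loss = commitment(quantised, encoded)  # pulls encoder outputs towards the codebook
loss.backward()                        # gradients reach the encoder, not the codebook
```

In a full VQVAE this term would be added to a reconstruction loss (and a codebook loss); the sketch only exercises the pieces added in this commit.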
test/modules/layers/test_quantisation.py
@@ -0,0 +1,100 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn
from torchmultimodal.modules.layers.quantisation import Quantisation


class TestQuantisation(unittest.TestCase):
    """
    Test the Quantisation class
    """

    def setUp(self):
        self.num_embeddings = 4
        self.embedding_dim = 5

        # This is 2x5x3
        self.encoded = torch.Tensor(
            [
                [[-1, 0, 1], [2, 1, 0], [0, -1, -1], [0, 2, -1], [-2, -1, 1]],
                [[2, 2, -1], [1, -1, -2], [0, 0, 0], [1, 2, 1], [1, 0, 0]],
            ]
        )
        # This is 4x5
        self.embedding_weights = torch.Tensor(
            [[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]]
        )
        # This is 4x3
        self.input_tensor_flat = torch.Tensor(
            [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]
        )

        self.vq = Quantisation(
            num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim
        )
        self.vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights)

    def test_quantised_output(self):
        actual = self.vq(self.encoded)
        # This is shape (2,5,3)
        expected = torch.Tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )

    def test_preprocess(self):
        encoded_flat, permuted_shape = self.vq._preprocess(self.encoded)

        expected_flat_shape = torch.tensor([6, 5])
        expected_permuted_shape = torch.tensor([2, 3, 5])

        actual_flat_shape = torch.tensor(encoded_flat.shape)
        actual_permuted_shape = torch.tensor(permuted_shape)

        assert torch.equal(
            actual_flat_shape, expected_flat_shape
        ), f"actual flattened shape: {actual_flat_shape}, expected flattened shape: {expected_flat_shape}"

        assert torch.equal(
            actual_permuted_shape, expected_permuted_shape
        ), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}"

    def test_preprocess_channel_dim_assertion(self):
        with self.assertRaises(ValueError):
            encoded_flat, permuted_shape = self.vq._preprocess(self.encoded[:, :4, :])

    def test_postprocess(self):
        quantised = self.vq._postprocess(self.input_tensor_flat, torch.Size([2, 2, 3]))
        actual_quantised_shape = torch.tensor(quantised.shape)
        expected_quantised_shape = torch.tensor([2, 3, 2])

        assert torch.equal(
            actual_quantised_shape, expected_quantised_shape
        ), f"actual quantised shape: {actual_quantised_shape}, expected quantised shape: {expected_quantised_shape}"
```
test/modules/losses/test_commitment.py
@@ -0,0 +1,33 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torchmultimodal.modules.losses.vqvae import CommitmentLoss


class TestCommitment(unittest.TestCase):
    """
    Test the Commitment Loss
    """

    def setUp(self):
        self.quantized = torch.Tensor([[-1, 0, 1], [2, 1, 0]])
        self.encoded = torch.Tensor([[-2, -1, 0], [0, 2, -2]])
        self.commitment = CommitmentLoss()

    def test_loss_value(self):
        loss = self.commitment(self.quantized, self.encoded)

        actual = loss.item()
        expected = 2.0

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )
```
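The expected value of 2.0 follows directly from the mean-squared-error definition with the default `commitment_cost` of 1.0; a quick recomputation (not part of the test file):

```python
import torch
from torch.nn import functional as F

quantized = torch.Tensor([[-1, 0, 1], [2, 1, 0]])
encoded = torch.Tensor([[-2, -1, 0], [0, 2, -2]])

# Differences: [[1, 1, 1], [2, -1, 2]]; squared: [[1, 1, 1], [4, 1, 4]].
# Mean over all six elements: (1 + 1 + 1 + 4 + 1 + 4) / 6 = 2.0.
print(F.mse_loss(quantized, encoded).item())  # 2.0
```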
torchmultimodal/modules/layers/quantisation.py
@@ -0,0 +1,98 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor, Size


class Quantisation(nn.Module):
    """Quantisation provides an embedding layer that takes in the output of an encoder
    and performs a nearest-neighbor lookup in the embedding space.
    Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    to generate high-fidelity images, videos, and audio data.
    Args:
        num_embeddings (int): the number of vectors in the embedding space
        embedding_dim (int): the dimensionality of the embedding vectors
    Inputs:
        x (Tensor): Tensor containing a batch of encoder outputs.
            Expects dimensions to be batch x channel x n dims.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(
            -1 / self.num_embeddings, 1 / self.num_embeddings
        )
        self.quantised_vectors = None

    def _preprocess(self, x: Tensor):
        # Rearrange from batch x channel x n dims to batch x n dims x channel
        new_dims = (0,) + tuple(range(2, len(x.shape))) + (1,)
        x_permuted = x.permute(new_dims).contiguous()
        permuted_shape = x_permuted.shape

        # Flatten input
        x_flat = x_permuted.view(-1, permuted_shape[-1])
        # Channel dimension should be embedding dim so that each element in the encoder
        # output volume gets associated with a single embedding vector
        if x_flat.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected {x_flat.shape[-1]} to be embedding size of {self.embedding_dim}"
            )

        return x_flat, permuted_shape

    def _postprocess(self, quantised_flat: Tensor, permuted_shape: Size):
        # Rearrange back to batch x channel x n dims
        num_dims = len(permuted_shape)
        quantised_permuted = quantised_flat.view(permuted_shape)
        old_dims = (0,) + (num_dims - 1,) + tuple(range(1, num_dims - 1))
        quantised = quantised_permuted.permute(old_dims).contiguous()

        return quantised

    def quantise(self, x_flat: Tensor):
        # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2
        distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2

        # Encoding - select closest embedding vectors
        encoding_indices = torch.argmin(distances, dim=1)

        # Quantise
        quantised_flat = self.embedding(encoding_indices)

        # Straight through estimator
        quantised_flat = x_flat + (quantised_flat - x_flat).detach()

        return quantised_flat

    def forward(self, x: Tensor):
        # Reshape and flatten encoder output for quantisation
        x_flat, permuted_shape = self._preprocess(x)

        # Quantisation via nearest-neighbor lookup
        quantised_flat = self.quantise(x_flat)

        # Reshape back to original dims
        quantised = self._postprocess(quantised_flat, permuted_shape)
        self.quantised_vectors = quantised

        return quantised

    def get_quantised_vectors(self):
        # Retrieve the previously quantised vectors without forward passing again
        if self.quantised_vectors is None:
            raise Exception(
                "quantisation has not yet been performed, please run a forward pass"
            )
        return self.quantised_vectors
```
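The straight-through estimator in `quantise` is what keeps the layer trainable even though `argmin` has no gradient: the forward pass returns codebook vectors, while the backward pass treats quantisation as the identity, so gradients flow back to the encoder (and not into the embedding table along this path). A minimal illustration, not part of the commit:

```python
import torch
from torchmultimodal.modules.layers.quantisation import Quantisation

vq = Quantisation(num_embeddings=4, embedding_dim=5)

encoded = torch.randn(2, 5, 3, requires_grad=True)
quantised = vq(encoded)     # forward value: nearest codebook vectors

quantised.sum().backward()  # backward: behaves like the identity w.r.t. the input
print(torch.equal(encoded.grad, torch.ones_like(encoded)))  # True
```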
torchmultimodal/modules/losses/vqvae.py
@@ -0,0 +1,30 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

from torch import nn, Tensor
from torch.nn import functional as F


class CommitmentLoss(nn.Module):
    """Commitment loss calculates the mean Euclidean distance between pairs of encoder output vectors
    and their corresponding quantised vectors. It encourages an encoder to generate outputs closer to an embedding.
    This is the beta in Eq. 3 of Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    Args:
        commitment_cost (float): multiplicative weight for the commitment loss value
    """

    def __init__(self, commitment_cost: float = 1.0, **kwargs: Any):
        super().__init__()
        self.commitment_cost = commitment_cost

    def forward(self, quantised: Tensor, encoded: Tensor):
        # Quantised vectors must be detached because commitment loss only lets gradient flow through encoder output
        loss = F.mse_loss(quantised.detach(), encoded) * self.commitment_cost

        return loss
```
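Because the quantised vectors are detached before the MSE, the commitment term produces gradients only for the encoder side; the codebook is never updated by it. A small sketch of that behaviour (illustrative, not part of the commit):

```python
import torch
from torchmultimodal.modules.losses.vqvae import CommitmentLoss

commitment = CommitmentLoss(commitment_cost=0.25)

encoded = torch.randn(2, 3, requires_grad=True)
quantised = torch.randn(2, 3, requires_grad=True)

loss = commitment(quantised, encoded)
loss.backward()

print(encoded.grad is not None)  # True: the encoder output is pulled towards the codebook
print(quantised.grad)            # None: no gradient flows into the quantised vectors
```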