Added quantisation layer, commitment loss, and associated unit tests (facebookresearch#29)

Summary:
The Vector Quantized Variational Autoencoder (VQVAE) is widely used to generate multimodal data (audio, image, or video samples). A well-known model built on this approach is OpenAI's DALL-E, which generates images from text captions. The core component of the VQVAE is the quantisation layer: a discrete embedding space whose vectors encode the outputs of an encoder via nearest-neighbor lookup. Here, this layer is implemented as its own module, along with the commitment loss, which encourages the encoder to produce outputs close to the embedding vectors.
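
For context, here is a minimal usage sketch of the two new modules together (the sizes, shapes, and commitment cost below are illustrative only, not taken from this PR):

```python
import torch

from torchmultimodal.modules.layers.quantisation import Quantisation
from torchmultimodal.modules.losses.vqvae import CommitmentLoss

# Illustrative sizes; any compatible num_embeddings / embedding_dim pair works
quantisation = Quantisation(num_embeddings=4, embedding_dim=5)
commitment = CommitmentLoss(commitment_cost=1.0)

encoded = torch.randn(2, 5, 3)  # batch x channel (= embedding_dim) x n dims
quantised = quantisation(encoded)  # nearest-neighbor lookup, straight-through gradient
loss = commitment(quantised, encoded)  # penalises encoder outputs far from their codes
```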

Pull Request resolved: facebookresearch#29

Test Plan:
Unit tests were added in this PR and run with pytest:

`pytest -vv test/modules/losses/test_commitment.py`

*(screenshot: pytest output for the commitment loss tests)*

`pytest -vv test/modules/layers/test_quantisation.py`

*(screenshot: pytest output for the quantisation tests)*

Reviewed By: langong347

Differential Revision: D36030926

Pulled By: RdoubleA

fbshipit-source-id: 6478df6d9cd3f07a48cb4a9ab492cde1679cf1b6
RdoubleA authored and facebook-github-bot committed May 3, 2022
1 parent 28d222a commit 292219e
Showing 4 changed files with 261 additions and 0 deletions.
100 changes: 100 additions & 0 deletions test/modules/layers/test_quantisation.py
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn
from torchmultimodal.modules.layers.quantisation import Quantisation


class TestQuantisation(unittest.TestCase):
    """
    Test the Quantisation class
    """

    def setUp(self):
        self.num_embeddings = 4
        self.embedding_dim = 5

        # This is 2x5x3
        self.encoded = torch.Tensor(
            [
                [[-1, 0, 1], [2, 1, 0], [0, -1, -1], [0, 2, -1], [-2, -1, 1]],
                [[2, 2, -1], [1, -1, -2], [0, 0, 0], [1, 2, 1], [1, 0, 0]],
            ]
        )
        # This is 4x5
        self.embedding_weights = torch.Tensor(
            [[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]]
        )
        # This is 4x3
        self.input_tensor_flat = torch.Tensor(
            [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]
        )

        self.vq = Quantisation(
            num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim
        )
        self.vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights)

    def test_quantised_output(self):
        actual = self.vq(self.encoded)
        # This is shape (2,5,3)
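        # Each vector of `encoded` along the channel dim is replaced by its nearest
        # embedding row; e.g. encoded[0, :, 2] = [1, 0, -1, -1, 1] is closest to
        # embedding_weights[0] = [1, 0, -1, -1, 2]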
        expected = torch.Tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )

    def test_preprocess(self):
        encoded_flat, permuted_shape = self.vq._preprocess(self.encoded)

        expected_flat_shape = torch.tensor([6, 5])
        expected_permuted_shape = torch.tensor([2, 3, 5])

        actual_flat_shape = torch.tensor(encoded_flat.shape)
        actual_permuted_shape = torch.tensor(permuted_shape)

        assert torch.equal(
            actual_flat_shape, expected_flat_shape
        ), f"actual flattened shape: {actual_flat_shape}, expected flattened shape: {expected_flat_shape}"

        assert torch.equal(
            actual_permuted_shape, expected_permuted_shape
        ), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}"

    def test_preprocess_channel_dim_assertion(self):
        with self.assertRaises(ValueError):
            encoded_flat, permuted_shape = self.vq._preprocess(self.encoded[:, :4, :])

    def test_postprocess(self):
        quantised = self.vq._postprocess(self.input_tensor_flat, torch.Size([2, 2, 3]))
        actual_quantised_shape = torch.tensor(quantised.shape)
        expected_quantised_shape = torch.tensor([2, 3, 2])

        assert torch.equal(
            actual_quantised_shape, expected_quantised_shape
        ), f"actual quantised shape: {actual_quantised_shape}, expected quantised shape: {expected_quantised_shape}"
33 changes: 33 additions & 0 deletions test/modules/losses/test_commitment.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torchmultimodal.modules.losses.vqvae import CommitmentLoss


class TestCommitment(unittest.TestCase):
    """
    Test the Commitment Loss
    """

    def setUp(self):
        self.quantized = torch.Tensor([[-1, 0, 1], [2, 1, 0]])
        self.encoded = torch.Tensor([[-2, -1, 0], [0, 2, -2]])
        self.commitment = CommitmentLoss()

    def test_loss_value(self):
        loss = self.commitment(self.quantized, self.encoded)

        actual = loss.item()
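        # With the default commitment_cost of 1.0, loss = MSE(quantized, encoded)
        # = (1 + 1 + 1 + 4 + 1 + 4) / 6 = 2.0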
        expected = 2.0

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )
98 changes: 98 additions & 0 deletions torchmultimodal/modules/layers/quantisation.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor, Size


class Quantisation(nn.Module):
    """Quantisation provides an embedding layer that takes in the output of an encoder
    and performs a nearest-neighbor lookup in the embedding space.

    Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    to generate high-fidelity images, videos, and audio data.

    Args:
        num_embeddings (int): the number of vectors in the embedding space
        embedding_dim (int): the dimensionality of the embedding vectors

    Inputs:
        x (Tensor): Tensor containing a batch of encoder outputs.
                    Expects dimensions to be batch x channel x n dims.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(
            -1 / self.num_embeddings, 1 / self.num_embeddings
        )
        self.quantised_vectors = None

    def _preprocess(self, x: Tensor):
        # Rearrange from batch x channel x n dims to batch x n dims x channel
        new_dims = (0,) + tuple(range(2, len(x.shape))) + (1,)
        x_permuted = x.permute(new_dims).contiguous()
        permuted_shape = x_permuted.shape

        # Flatten input
        x_flat = x_permuted.view(-1, permuted_shape[-1])
        # channel dimension should be embedding dim so that each element in encoder
        # output volume gets associated with single embedding vector
        if x_flat.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected input of embedding dim {self.embedding_dim} but got {x_flat.shape[-1]}"
            )

        return x_flat, permuted_shape

    def _postprocess(self, quantised_flat: Tensor, permuted_shape: Size):
        # Rearrange back to batch x channel x n dims
        num_dims = len(permuted_shape)
        quantised_permuted = quantised_flat.view(permuted_shape)
        old_dims = (0,) + (num_dims - 1,) + tuple(range(1, num_dims - 1))
        quantised = quantised_permuted.permute(old_dims).contiguous()

        return quantised

    def quantise(self, x_flat: Tensor):
        # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2
        distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2

        # Encoding - select closest embedding vectors
        encoding_indices = torch.argmin(distances, dim=1)

        # Quantise
        quantised_flat = self.embedding(encoding_indices)

        # Straight through estimator
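        # The value equals quantised_flat, but the gradient is taken w.r.t. x_flat,
        # letting backprop bypass the non-differentiable argmin lookup above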
        quantised_flat = x_flat + (quantised_flat - x_flat).detach()

        return quantised_flat

    def forward(self, x: Tensor):
        # Reshape and flatten encoder output for quantisation
        x_flat, permuted_shape = self._preprocess(x)

        # Quantisation via nearest neighbor lookup
        quantised_flat = self.quantise(x_flat)

        # Reshape back to original dims
        quantised = self._postprocess(quantised_flat, permuted_shape)
        self.quantised_vectors = quantised

        return quantised

    def get_quantised_vectors(self):
        # Retrieve the previously quantised vectors without forward passing again
        if self.quantised_vectors is None:
            raise Exception(
                "quantisation has not yet been performed, please run a forward pass"
            )
        return self.quantised_vectors
30 changes: 30 additions & 0 deletions torchmultimodal/modules/losses/vqvae.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

from torch import nn, Tensor
from torch.nn import functional as F


class CommitmentLoss(nn.Module):
    """Commitment loss computes the mean squared error between pairs of encoder output
    vectors and their corresponding quantised vectors. It encourages an encoder to generate
    outputs closer to an embedding. This is the beta-weighted term in Eq. 3 of Oord et al. 2017
    (https://arxiv.org/pdf/1711.00937.pdf).

    Args:
        commitment_cost (float): multiplicative weight for the commitment loss value
    """

    def __init__(self, commitment_cost: float = 1.0, **kwargs: Any):
        super().__init__()
        self.commitment_cost = commitment_cost

    def forward(self, quantised: Tensor, encoded: Tensor):
        # Quantised vectors must be detached because commitment loss only lets gradient flow through encoder output
        loss = F.mse_loss(quantised.detach(), encoded) * self.commitment_cost

        return loss
