Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added quantisation layer, commitment loss, and associated unit tests #29

Closed
wants to merge 7 commits into from
100 changes: 100 additions & 0 deletions test/modules/layers/test_quantisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn
from torchmultimodal.modules.layers.quantisation import Quantisation


class TestQuantisation(unittest.TestCase):
    """Unit tests for the ``Quantisation`` layer.

    A fixed 4 x 5 codebook is installed in ``setUp`` so that the nearest-neighbour
    lookup performed by the layer is deterministic and can be checked by hand.
    """

    def setUp(self):
        self.num_embeddings = 4
        self.embedding_dim = 5

        # Encoder output of shape 2 x 5 x 3 (batch x channel x n dims)
        self.encoded = torch.tensor(
            [
                [[-1, 0, 1], [2, 1, 0], [0, -1, -1], [0, 2, -1], [-2, -1, 1]],
                [[2, 2, -1], [1, -1, -2], [0, 0, 0], [1, 2, 1], [1, 0, 0]],
            ],
            dtype=torch.float32,
        )
        # Codebook of shape 4 x 5 (num_embeddings x embedding_dim)
        self.embedding_weights = torch.tensor(
            [[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]],
            dtype=torch.float32,
        )
        # Flat 4 x 3 tensor, only used to check the reshaping done by _postprocess
        self.input_tensor_flat = torch.ones(4, 3)

        self.vq = Quantisation(
            num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim
        )
        # Replace the randomly initialised codebook with the known weights
        self.vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights)

    def test_quantised_output(self):
        actual = self.vq(self.encoded)
        # Shape (2, 5, 3): every length-5 channel vector is replaced by its
        # closest row of ``embedding_weights``
        expected = torch.tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )

    def test_preprocess(self):
        encoded_flat, permuted_shape = self.vq._preprocess(self.encoded)

        # 2 x 5 x 3 is permuted to 2 x 3 x 5 and then flattened to 6 x 5
        self.assertEqual(
            encoded_flat.shape,
            torch.Size([6, 5]),
            "unexpected flattened shape",
        )
        self.assertEqual(
            permuted_shape,
            torch.Size([2, 3, 5]),
            "unexpected permuted shape",
        )

    def test_preprocess_channel_dim_assertion(self):
        # Only 4 channels while embedding_dim is 5, so the layer must reject it
        with self.assertRaises(ValueError):
            self.vq._preprocess(self.encoded[:, :4, :])

    def test_postprocess(self):
        quantised = self.vq._postprocess(self.input_tensor_flat, torch.Size([2, 2, 3]))

        # 4 x 3 is viewed as 2 x 2 x 3, then the last dim moves back to position 1
        self.assertEqual(
            quantised.shape,
            torch.Size([2, 3, 2]),
            "unexpected quantised shape",
        )
33 changes: 33 additions & 0 deletions test/modules/losses/test_commitment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torchmultimodal.modules.losses.vqvae import CommitmentLoss


class TestCommitment(unittest.TestCase):
    """Unit tests for ``CommitmentLoss``."""

    def setUp(self):
        self.quantised = torch.tensor([[-1.0, 0.0, 1.0], [2.0, 1.0, 0.0]])
        self.encoded = torch.tensor([[-2.0, -1.0, 0.0], [0.0, 2.0, -2.0]])
        self.commitment = CommitmentLoss()

    def test_loss_value(self):
        loss = self.commitment(self.quantised, self.encoded)

        actual = loss.item()
        # Element-wise differences are [[1, 1, 1], [2, -1, 2]]; their squares sum
        # to 12 over 6 elements, so the mean squared error is 2.0
        expected = 2.0

        torch.testing.assert_close(
            actual,
            expected,
            msg=f"actual: {actual}, expected: {expected}",
        )
98 changes: 98 additions & 0 deletions torchmultimodal/modules/layers/quantisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor, Size


class Quantisation(nn.Module):
    """Quantisation provides an embedding layer that takes in the output of an encoder
    and performs a nearest-neighbor lookup in the embedding space.

    Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    to generate high-fidelity images, videos, and audio data.

    Args:
        num_embeddings (int): the number of vectors in the embedding space
        embedding_dim (int): the dimensionality of the embedding vectors

    Inputs:
        x (Tensor): Tensor containing a batch of encoder outputs.
            Expects dimensions to be batch x channel x n dims, where the
            channel dimension equals ``embedding_dim``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        # Initialise the codebook uniformly in [-1/K, 1/K) for K embeddings
        nn.init.uniform_(
            self.embedding.weight, -1 / self.num_embeddings, 1 / self.num_embeddings
        )
        # Output of the most recent forward pass; None until forward is called
        self.quantised_vectors = None

    def _preprocess(self, x: Tensor):
        """Move the channel dimension last and flatten to a 2D tensor.

        Args:
            x (Tensor): encoder output of shape batch x channel x n dims.

        Returns:
            A tuple ``(x_flat, permuted_shape)`` where ``x_flat`` has shape
            (-1, channel) and ``permuted_shape`` is the batch x n dims x channel
            shape needed by ``_postprocess`` to undo the flattening.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions or its channel
                dimension differs from ``embedding_dim``.
        """
        if x.dim() < 2:
            raise ValueError(
                f"Expected input with at least 2 dimensions (batch x channel), got {x.dim()}"
            )
        # channel dimension should be embedding dim so that each element in encoder
        # output volume gets associated with single embedding vector.
        # Checked before permuting so an invalid input does not pay for a copy.
        if x.shape[1] != self.embedding_dim:
            raise ValueError(
                f"Expected channel dimension {x.shape[1]} to be embedding size of {self.embedding_dim}"
            )

        # Rearrange from batch x channel x n dims to batch x n dims x channel
        new_dims = (0,) + tuple(range(2, x.dim())) + (1,)
        x_permuted = x.permute(new_dims).contiguous()
        permuted_shape = x_permuted.shape

        # Flatten input so that every row is one vector of length embedding_dim
        x_flat = x_permuted.view(-1, permuted_shape[-1])

        return x_flat, permuted_shape

    def _postprocess(self, quantised_flat: Tensor, permuted_shape: Size) -> Tensor:
        """Undo ``_preprocess``: reshape and move the last dimension back to position 1.

        Args:
            quantised_flat (Tensor): flattened tensor whose element count matches
                ``permuted_shape``.
            permuted_shape (Size): batch x n dims x channel shape.

        Returns:
            Tensor of shape batch x channel x n dims.
        """
        # Rearrange back to batch x channel x n dims
        num_dims = len(permuted_shape)
        quantised_permuted = quantised_flat.view(permuted_shape)
        old_dims = (0,) + (num_dims - 1,) + tuple(range(1, num_dims - 1))
        quantised = quantised_permuted.permute(old_dims).contiguous()

        return quantised

    def quantise(self, x_flat: Tensor) -> Tensor:
        """Replace each row of ``x_flat`` with its nearest embedding vector.

        Args:
            x_flat (Tensor): 2D tensor whose rows have length ``embedding_dim``.

        Returns:
            Tensor with the same shape as ``x_flat`` whose values are the selected
            embedding vectors, and whose gradient is passed straight through to
            ``x_flat``.
        """
        # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2
        distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2

        # Encoding - select closest embedding vectors
        encoding_indices = torch.argmin(distances, dim=1)

        # Quantise
        quantised_flat = self.embedding(encoding_indices)

        # Straight through estimator: the forward value is the embedding vector,
        # while the detached difference lets gradients flow to x_flat unchanged
        quantised_flat = x_flat + (quantised_flat - x_flat).detach()

        return quantised_flat

    def forward(self, x: Tensor) -> Tensor:
        """Quantise a batch of encoder outputs.

        Args:
            x (Tensor): encoder output of shape batch x channel x n dims.

        Returns:
            Tensor with the same shape as ``x`` holding the quantised vectors.
            The result is also cached on ``self.quantised_vectors``.
        """
        # Reshape and flatten encoder output for quantisation
        x_flat, permuted_shape = self._preprocess(x)

        # Quantisation via nearest neighbor lookup
        quantised_flat = self.quantise(x_flat)

        # Reshape back to original dims
        quantised = self._postprocess(quantised_flat, permuted_shape)
        self.quantised_vectors = quantised

        return quantised

    def get_quantised_vectors(self) -> Tensor:
        """Return the output of the latest forward pass without recomputing it.

        Raises:
            RuntimeError: if no forward pass has been run yet.
        """
        if self.quantised_vectors is None:
            raise RuntimeError(
                "quantisation has not yet been performed, please run a forward pass"
            )
        return self.quantised_vectors
30 changes: 30 additions & 0 deletions torchmultimodal/modules/losses/vqvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

from torch import nn, Tensor
from torch.nn import functional as F


class CommitmentLoss(nn.Module):
    """Commitment loss calculates the mean squared error between encoder outputs and
    their corresponding quantised vectors, scaled by a commitment cost. It encourages
    an encoder to generate outputs closer to an embedding.
    This is the beta-weighted term in Eq. 3 of Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)

    Args:
        commitment_cost (float): multiplicative weight for the commitment loss value
        **kwargs (Any): accepted for call-site compatibility; not used by this module

    Inputs:
        quantised (Tensor): quantised vectors; treated as a constant target
        encoded (Tensor): encoder outputs that the gradient flows through
    """

    def __init__(self, commitment_cost: float = 1.0, **kwargs: Any) -> None:
        super().__init__()
        self.commitment_cost = commitment_cost

    def forward(self, quantised: Tensor, encoded: Tensor) -> Tensor:
        """Return ``commitment_cost`` times the element-wise mean squared error."""
        # Quantised vectors must be detached because commitment loss only lets gradient flow through encoder output
        loss = F.mse_loss(encoded, quantised.detach()) * self.commitment_cost

        return loss