From 669513bbc24f20ebdb89531dce7b76d0c65e94fb Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Tue, 26 Apr 2022 16:00:23 -0700 Subject: [PATCH 1/7] added quantisation layer, commitment loss, and associated unit tests --- test/modules/layers/test_quantisation.py | 119 ++++++++++++++++++ test/modules/losses/test_commitment.py | 35 ++++++ .../modules/layers/quantisation.py | 61 +++++++++ torchmultimodal/modules/losses/vqvae.py | 28 +++++ torchmultimodal/utils/preprocess.py | 48 +++++++ 5 files changed, 291 insertions(+) create mode 100644 test/modules/layers/test_quantisation.py create mode 100644 test/modules/losses/test_commitment.py create mode 100644 torchmultimodal/modules/layers/quantisation.py create mode 100644 torchmultimodal/modules/losses/vqvae.py create mode 100644 torchmultimodal/utils/preprocess.py diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py new file mode 100644 index 00000000..78fe4481 --- /dev/null +++ b/test/modules/layers/test_quantisation.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from torch import nn +from torchmultimodal.modules.layers.quantisation import Quantisation +from torchmultimodal.utils.preprocess import ( + flatten_to_channel_vectors, + reshape_from_channel_vectors, +) + + +class TestQuantisation(unittest.TestCase): + """ + Test the Quantisation class + """ + + def setUp(self): + torch.set_printoptions(precision=10) + torch.manual_seed(4) + self.num_embeddings = 4 + self.embedding_dim = 5 + self.encoded = torch.randn((2, self.embedding_dim, 3, 3)) + self.embedding_weights = torch.randn((self.num_embeddings, self.embedding_dim)) + + def test_quantised_output(self): + vq = Quantisation( + num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim + ) + vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) + x, permuted_shape = flatten_to_channel_vectors(self.encoded, 1) + output = vq(x) + actual = reshape_from_channel_vectors(output, permuted_shape, 1) + + # This is shape (2,5,3,3) + expected = torch.Tensor( + [ + [ + [ + [-1.1265823841, -0.1376807690, 0.7218121886], + [-0.1376807690, 0.7218121886, 0.7218121886], + [-1.1265823841, -0.1376807690, -1.1265823841], + ], + [ + [-0.5252661109, 0.3244057596, -0.8940765262], + [0.3244057596, -0.8940765262, -0.8940765262], + [-0.5252661109, 0.3244057596, -0.5252661109], + ], + [ + [-0.9950634241, -1.0523844957, 1.7175949812], + [-1.0523844957, 1.7175949812, 1.7175949812], + [-0.9950634241, -1.0523844957, -0.9950634241], + ], + [ + [0.2679379284, -0.4480970800, -0.3190571964], + [-0.4480970800, -0.3190571964, -0.3190571964], + [0.2679379284, -0.4480970800, 0.2679379284], + ], + [ + [-0.6253433824, -0.5198931098, -0.8529881239], + [-0.5198931098, -0.8529881239, -0.8529881239], + [-0.6253433824, -0.5198931098, -0.6253433824], + ], + ], + [ + [ + [-1.6703201532, -1.1265823841, -0.1376807690], + [-1.1265823841, 0.7218121886, -1.1265823841], + [-0.1376807690, -0.1376807690, -0.1376807690], + ], + [ + [0.8635767698, -0.5252661109, 0.3244057596], + [-0.5252661109, -0.8940765262, -0.5252661109], + [0.3244057596, 0.3244057596, 0.3244057596], + ], + [ + [-1.5300362110, -0.9950634241, -1.0523844957], + [-0.9950634241, 1.7175949812, -0.9950634241], + [-1.0523844957, -1.0523844957, -1.0523844957], + ], + [ + [0.5375117064, 0.2679379284, 
-0.4480970800], + [0.2679379284, -0.3190571964, 0.2679379284], + [-0.4480970800, -0.4480970800, -0.4480970800], + ], + [ + [-1.6273639202, -0.6253433824, -0.5198931098], + [-0.6253433824, -0.8529881239, -0.6253433824], + [-0.5198931098, -0.5198931098, -0.5198931098], + ], + ], + ] + ) + + torch.testing.assert_close( + actual, + expected, + msg=f"actual: {actual}, expected: {expected}", + ) + + def test_quantised_shape(self): + vq = Quantisation( + num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim + ) + vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) + x, permuted_shape = flatten_to_channel_vectors(self.encoded, 1) + output = vq(x) + output = reshape_from_channel_vectors(output, permuted_shape, 1) + actual = torch.tensor(output.shape) + expected = torch.tensor([2, 5, 3, 3]) + + assert torch.equal( + actual, expected + ), f"actual shape: {actual}, expected shape: {expected}" diff --git a/test/modules/losses/test_commitment.py b/test/modules/losses/test_commitment.py new file mode 100644 index 00000000..2bb93c3e --- /dev/null +++ b/test/modules/losses/test_commitment.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from torchmultimodal.modules.losses.vqvae import CommitmentLoss + + +class TestCommitment(unittest.TestCase): + """ + Test the Commitment Loss + """ + + def setUp(self): + torch.set_printoptions(precision=10) + torch.manual_seed(4) + self.quantized = torch.randn((2, 3)) + self.encoded = torch.randn((2, 3)) + + def test_loss_value(self): + commitment = CommitmentLoss() + loss = commitment(self.quantized, self.encoded) + + actual = loss.loss.item() + expected = 1.2070025206 + + torch.testing.assert_close( + actual, + expected, + msg=f"actual: {actual}, expected: {expected}", + ) diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py new file mode 100644 index 00000000..2ee9e130 --- /dev/null +++ b/torchmultimodal/modules/layers/quantisation.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + + +class Quantisation(nn.Module): + """ + Embedding layer that takes in a collection of flattened vectors and finds closest embedding vectors + to each flattened vector and outputs those selected embedding vectors. Also known as vector quantisation. 
+ """ + + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.embedding.weight.data.normal_() + self.quantised_vectors = None + + def forward(self, x_flat): + x_shape = x_flat.shape + # channel dimension should be embedding dim so that each element in encoder + # output volume gets associated with single embedding vector + assert ( + x_shape[-1] == self.embedding_dim + ), f"Expected {x_shape[-1]} to be embedding size of {self.embedding_dim}" + + # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2 + w_t = self.embedding.weight.t() + distances = ( + torch.sum(x_flat ** 2, dim=1, keepdim=True) + - 2 * torch.matmul(x_flat, w_t) + + torch.sum(w_t ** 2, dim=0, keepdim=True) + ) + + # Encoding - select closest embedding vectors + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros( + encoding_indices.shape[0], self.num_embeddings, device=x_flat.device + ) + encodings.scatter_(1, encoding_indices, 1) + + # Quantise + quantised_flattened = torch.matmul(encodings, self.embedding.weight) + self.quantised_vectors = quantised_flattened + + return quantised_flattened + + def get_quantised_vectors(self): + # Retrieve the previously quantised vectors without forward passing again + if self.quantised_vectors is None: + raise Exception( + "quantisation has not yet been performed, please run a forward pass" + ) + return self.quantised_vectors diff --git a/torchmultimodal/modules/losses/vqvae.py b/torchmultimodal/modules/losses/vqvae.py new file mode 100644 index 00000000..e3490325 --- /dev/null +++ b/torchmultimodal/modules/losses/vqvae.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any + +from torch import nn, Tensor + + +@dataclass +class CommitmentLossOutput: + loss: Tensor + + +class CommitmentLoss(nn.Module): + def __init__(self, commitment_cost: float = 1.0, **kwargs: Any): + super().__init__() + self.mse_loss = nn.MSELoss(reduction="mean") + self.commitment_cost = commitment_cost + + def forward(self, quantised: Tensor, encoded: Tensor): + # Quantised vectors must be detached because commitment loss only lets gradient flow through encoder output + loss = self.mse_loss(quantised.detach(), encoded) * self.commitment_cost + + return CommitmentLossOutput(loss=loss) diff --git a/torchmultimodal/utils/preprocess.py b/torchmultimodal/utils/preprocess.py new file mode 100644 index 00000000..376cff3d --- /dev/null +++ b/torchmultimodal/utils/preprocess.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# +# preprocess.py +# +# This file contains utility functions to process, reshape, etc. 
Tensors +# +# + + +def flatten_to_channel_vectors(x, channel_dim): + """ + Takes an input tensor and flattens across all dimensions except the specified dimension + An example is flattening an encoder output volume of BxCxHxW to (B*H*W)xC for a VQVAE model + """ + + # Move channel dim to end + new_dims = tuple( + [i for i in range(len(x.shape)) if i != channel_dim] + [channel_dim] + ) + x = x.permute(new_dims).contiguous() + permuted_shape = x.shape + + # Flatten input + x_flat = x.view(-1, permuted_shape[-1]) + + return x_flat, permuted_shape + + +def reshape_from_channel_vectors(x_flat, permuted_shape, orig_channel_dim): + """ + The inverse of flatten_to_channel_vectors + """ + # Reshape flattened vectors to permuted shape + x_out = x_flat.view(permuted_shape) + + # Move channel dim (last) back to original spot + old_dims = list(range(len(permuted_shape))) + old_dims.pop(-1) + old_dims.insert(orig_channel_dim, len(permuted_shape) - 1) + x_out = x_out.permute(old_dims).contiguous() + + return x_out From 7e8c0a2823f1790cc251329866be5fd7402234b0 Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Wed, 27 Apr 2022 12:09:39 -0700 Subject: [PATCH 2/7] docstrings, type annotations, simplified dist calc, neighbor lookup, added straight thru estimator --- test/modules/layers/test_quantisation.py | 3 +- test/modules/losses/test_commitment.py | 3 +- .../modules/layers/quantisation.py | 39 ++++++++++--------- torchmultimodal/modules/losses/vqvae.py | 18 +++++---- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index 78fe4481..7ffd714b 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -9,6 +9,7 @@ import torch from torch import nn from torchmultimodal.modules.layers.quantisation import Quantisation +from torchmultimodal.test.test_utils import set_rng_seed from torchmultimodal.utils.preprocess import ( flatten_to_channel_vectors, reshape_from_channel_vectors, @@ -22,7 +23,7 @@ class TestQuantisation(unittest.TestCase): def setUp(self): torch.set_printoptions(precision=10) - torch.manual_seed(4) + set_rng_seed(4) self.num_embeddings = 4 self.embedding_dim = 5 self.encoded = torch.randn((2, self.embedding_dim, 3, 3)) diff --git a/test/modules/losses/test_commitment.py b/test/modules/losses/test_commitment.py index 2bb93c3e..76f3a6d3 100644 --- a/test/modules/losses/test_commitment.py +++ b/test/modules/losses/test_commitment.py @@ -8,6 +8,7 @@ import torch from torchmultimodal.modules.losses.vqvae import CommitmentLoss +from torchmultimodal.test.test_utils import set_rng_seed class TestCommitment(unittest.TestCase): @@ -17,7 +18,7 @@ class TestCommitment(unittest.TestCase): def setUp(self): torch.set_printoptions(precision=10) - torch.manual_seed(4) + set_rng_seed(4) self.quantized = torch.randn((2, 3)) self.encoded = torch.randn((2, 3)) diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index 2ee9e130..b6226b14 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -5,25 +5,33 @@ # LICENSE file in the root directory of this source tree. import torch -from torch import nn +from torch import nn, Tensor class Quantisation(nn.Module): - """ - Embedding layer that takes in a collection of flattened vectors and finds closest embedding vectors - to each flattened vector and outputs those selected embedding vectors. 
Also known as vector quantisation. + """Quantisation provides an embedding layer that takes in a collection of flattened vectors, usually the + output of an encoder architecture, and performs a nearest-neighbor lookup in the embedding space. + + Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf) to generate high-fidelity + images, videos, and audio data. + + Args: + num_embeddings (int): the number of vectors in the embedding space + embedding_dim (int): the dimensionality of the embedding vectors """ - def __init__(self, num_embeddings, embedding_dim): + def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__() self.embedding_dim = embedding_dim self.num_embeddings = num_embeddings self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) - self.embedding.weight.data.normal_() + self.embedding.weight.data.uniform_( + -1 / self.num_embeddings, 1 / self.num_embeddings + ) self.quantised_vectors = None - def forward(self, x_flat): + def forward(self, x_flat: Tensor): x_shape = x_flat.shape # channel dimension should be embedding dim so that each element in encoder # output volume gets associated with single embedding vector @@ -33,21 +41,16 @@ def forward(self, x_flat): # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2 w_t = self.embedding.weight.t() - distances = ( - torch.sum(x_flat ** 2, dim=1, keepdim=True) - - 2 * torch.matmul(x_flat, w_t) - + torch.sum(w_t ** 2, dim=0, keepdim=True) - ) + distances = torch.cdist(x_flat, w_t, p=2.0) ** 2 # Encoding - select closest embedding vectors - encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) - encodings = torch.zeros( - encoding_indices.shape[0], self.num_embeddings, device=x_flat.device - ) - encodings.scatter_(1, encoding_indices, 1) + encoding_indices = torch.argmin(distances, dim=1) # Quantise - quantised_flattened = torch.matmul(encodings, self.embedding.weight) + quantised_flattened = self.embedding(encoding_indices) + + # Straight through estimator + quantised_flattened = x_flat + (quantised_flattened - x_flat).detach() self.quantised_vectors = quantised_flattened return quantised_flattened diff --git a/torchmultimodal/modules/losses/vqvae.py b/torchmultimodal/modules/losses/vqvae.py index e3490325..dc2fa86e 100644 --- a/torchmultimodal/modules/losses/vqvae.py +++ b/torchmultimodal/modules/losses/vqvae.py @@ -4,25 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass from typing import Any from torch import nn, Tensor +from torch.nn import functional as F -@dataclass -class CommitmentLossOutput: - loss: Tensor +class CommitmentLoss(nn.Module): + """Commitment loss calculates the mean Euclidean distance between pairs of encoder output vectors + and their corresponding quantised vectors. It encourages an encoder to generate outputs closer to an embedding. + This is the beta in Eq. 3 of Oord et al. 
2017 (https://arxiv.org/pdf/1711.00937.pdf) + Args: + commitment_cost (float): multiplicative weight for the commitment loss value + """ -class CommitmentLoss(nn.Module): def __init__(self, commitment_cost: float = 1.0, **kwargs: Any): super().__init__() - self.mse_loss = nn.MSELoss(reduction="mean") self.commitment_cost = commitment_cost def forward(self, quantised: Tensor, encoded: Tensor): # Quantised vectors must be detached because commitment loss only lets gradient flow through encoder output - loss = self.mse_loss(quantised.detach(), encoded) * self.commitment_cost + loss = F.mse_loss(quantised.detach(), encoded) * self.commitment_cost - return CommitmentLossOutput(loss=loss) + return loss From c32ae3e932f9d7eec97be71f07953b186ca4aba9 Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Wed, 27 Apr 2022 13:20:13 -0700 Subject: [PATCH 3/7] fixes from unit testing --- test/modules/layers/test_quantisation.py | 2 +- test/modules/losses/test_commitment.py | 4 ++-- torchmultimodal/modules/layers/quantisation.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index 7ffd714b..1ef7e8c8 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -7,9 +7,9 @@ import unittest import torch +from test.test_utils import set_rng_seed from torch import nn from torchmultimodal.modules.layers.quantisation import Quantisation -from torchmultimodal.test.test_utils import set_rng_seed from torchmultimodal.utils.preprocess import ( flatten_to_channel_vectors, reshape_from_channel_vectors, diff --git a/test/modules/losses/test_commitment.py b/test/modules/losses/test_commitment.py index 76f3a6d3..44248eaa 100644 --- a/test/modules/losses/test_commitment.py +++ b/test/modules/losses/test_commitment.py @@ -7,8 +7,8 @@ import unittest import torch +from test.test_utils import set_rng_seed from torchmultimodal.modules.losses.vqvae import CommitmentLoss -from torchmultimodal.test.test_utils import set_rng_seed class TestCommitment(unittest.TestCase): @@ -26,7 +26,7 @@ def test_loss_value(self): commitment = CommitmentLoss() loss = commitment(self.quantized, self.encoded) - actual = loss.loss.item() + actual = loss.item() expected = 1.2070025206 torch.testing.assert_close( diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index b6226b14..a415b8e3 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -40,8 +40,7 @@ def forward(self, x_flat: Tensor): ), f"Expected {x_shape[-1]} to be embedding size of {self.embedding_dim}" # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2 - w_t = self.embedding.weight.t() - distances = torch.cdist(x_flat, w_t, p=2.0) ** 2 + distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2 # Encoding - select closest embedding vectors encoding_indices = torch.argmin(distances, dim=1) From 00b88e929694206756470412f4df5e4cead91df0 Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Wed, 27 Apr 2022 16:02:08 -0700 Subject: [PATCH 4/7] refactors utils in preprocess.py inside of quantisation forward --- test/modules/layers/test_quantisation.py | 12 +---- .../modules/layers/quantisation.py | 26 ++++++---- torchmultimodal/utils/preprocess.py | 48 ------------------- 3 files changed, 20 insertions(+), 66 deletions(-) delete mode 100644 torchmultimodal/utils/preprocess.py diff --git 
a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index 1ef7e8c8..bc0752d3 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -10,10 +10,6 @@ from test.test_utils import set_rng_seed from torch import nn from torchmultimodal.modules.layers.quantisation import Quantisation -from torchmultimodal.utils.preprocess import ( - flatten_to_channel_vectors, - reshape_from_channel_vectors, -) class TestQuantisation(unittest.TestCase): @@ -34,9 +30,7 @@ def test_quantised_output(self): num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim ) vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) - x, permuted_shape = flatten_to_channel_vectors(self.encoded, 1) - output = vq(x) - actual = reshape_from_channel_vectors(output, permuted_shape, 1) + actual = vq(self.encoded) # This is shape (2,5,3,3) expected = torch.Tensor( @@ -109,9 +103,7 @@ def test_quantised_shape(self): num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim ) vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) - x, permuted_shape = flatten_to_channel_vectors(self.encoded, 1) - output = vq(x) - output = reshape_from_channel_vectors(output, permuted_shape, 1) + output = vq(self.encoded) actual = torch.tensor(output.shape) expected = torch.tensor([2, 5, 3, 3]) diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index a415b8e3..220276d8 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -31,13 +31,19 @@ def __init__(self, num_embeddings: int, embedding_dim: int): ) self.quantised_vectors = None - def forward(self, x_flat: Tensor): - x_shape = x_flat.shape + def forward(self, x: Tensor): + # Rearrange from batch x channel x n dims to batch x n dims x channel + new_dims = (0,) + tuple(range(2, len(x.shape))) + (1,) + x_permuted = x.permute(new_dims).contiguous() + permuted_shape = x_permuted.shape + + # Flatten input + x_flat = x_permuted.view(-1, permuted_shape[-1]) # channel dimension should be embedding dim so that each element in encoder # output volume gets associated with single embedding vector assert ( - x_shape[-1] == self.embedding_dim - ), f"Expected {x_shape[-1]} to be embedding size of {self.embedding_dim}" + x_flat.shape[-1] == self.embedding_dim + ), f"Expected {x_flat.shape[-1]} to be embedding size of {self.embedding_dim}" # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2 distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2 @@ -46,13 +52,17 @@ def forward(self, x_flat: Tensor): encoding_indices = torch.argmin(distances, dim=1) # Quantise - quantised_flattened = self.embedding(encoding_indices) + quantised_permuted = self.embedding(encoding_indices).view(permuted_shape) # Straight through estimator - quantised_flattened = x_flat + (quantised_flattened - x_flat).detach() - self.quantised_vectors = quantised_flattened + quantised_permuted = x_permuted + (quantised_permuted - x_permuted).detach() + + # Rearrange back to batch x channel x n dims + old_dims = (0,) + (len(x.shape) - 1,) + tuple(range(1, len(x.shape) - 1)) + quantised = quantised_permuted.permute(old_dims).contiguous() + self.quantised_vectors = quantised - return quantised_flattened + return quantised def get_quantised_vectors(self): # Retrieve the previously quantised vectors without forward passing again diff --git 
a/torchmultimodal/utils/preprocess.py b/torchmultimodal/utils/preprocess.py deleted file mode 100644 index 376cff3d..00000000 --- a/torchmultimodal/utils/preprocess.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -# -# preprocess.py -# -# This file contains utility functions to process, reshape, etc. Tensors -# -# - - -def flatten_to_channel_vectors(x, channel_dim): - """ - Takes an input tensor and flattens across all dimensions except the specified dimension - An example is flattening an encoder output volume of BxCxHxW to (B*H*W)xC for a VQVAE model - """ - - # Move channel dim to end - new_dims = tuple( - [i for i in range(len(x.shape)) if i != channel_dim] + [channel_dim] - ) - x = x.permute(new_dims).contiguous() - permuted_shape = x.shape - - # Flatten input - x_flat = x.view(-1, permuted_shape[-1]) - - return x_flat, permuted_shape - - -def reshape_from_channel_vectors(x_flat, permuted_shape, orig_channel_dim): - """ - The inverse of flatten_to_channel_vectors - """ - # Reshape flattened vectors to permuted shape - x_out = x_flat.view(permuted_shape) - - # Move channel dim (last) back to original spot - old_dims = list(range(len(permuted_shape))) - old_dims.pop(-1) - old_dims.insert(orig_channel_dim, len(permuted_shape) - 1) - x_out = x_out.permute(old_dims).contiguous() - - return x_out From 135cd76e0ef4dd468cb238be564538f1c37af99c Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Thu, 28 Apr 2022 11:35:20 -0700 Subject: [PATCH 5/7] refactored forward of Quantisation, updated unit tests with manual inputs --- test/modules/layers/test_quantisation.py | 120 ++++++++++-------- .../modules/layers/quantisation.py | 45 +++++-- 2 files changed, 100 insertions(+), 65 deletions(-) diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index bc0752d3..d1f0259c 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -22,8 +22,21 @@ def setUp(self): set_rng_seed(4) self.num_embeddings = 4 self.embedding_dim = 5 - self.encoded = torch.randn((2, self.embedding_dim, 3, 3)) - self.embedding_weights = torch.randn((self.num_embeddings, self.embedding_dim)) + # This is 2x5x3 + self.encoded = torch.Tensor( + [ + [[-1, 0, 1], [2, 1, 0], [0, -1, -1], [0, 2, -1], [-2, -1, 1]], + [[2, 2, -1], [1, -1, -2], [0, 0, 0], [1, 2, 1], [1, 0, 0]], + ] + ) + # This is 4x5 + self.embedding_weights = torch.Tensor( + [[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]] + ) + # This is 4x3 + self.test_tensor_flat = torch.Tensor( + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] + ) def test_quantised_output(self): vq = Quantisation( @@ -32,62 +45,22 @@ def test_quantised_output(self): vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) actual = vq(self.encoded) - # This is shape (2,5,3,3) + # This is shape (2,5,3) expected = torch.Tensor( [ [ - [ - [-1.1265823841, -0.1376807690, 0.7218121886], - [-0.1376807690, 0.7218121886, 0.7218121886], - [-1.1265823841, -0.1376807690, -1.1265823841], - ], - [ - [-0.5252661109, 0.3244057596, -0.8940765262], - [0.3244057596, -0.8940765262, -0.8940765262], - [-0.5252661109, 0.3244057596, -0.5252661109], - ], - [ - [-0.9950634241, -1.0523844957, 1.7175949812], - [-1.0523844957, 1.7175949812, 1.7175949812], - [-0.9950634241, -1.0523844957, -0.9950634241], - ], - [ - 
[0.2679379284, -0.4480970800, -0.3190571964], - [-0.4480970800, -0.3190571964, -0.3190571964], - [0.2679379284, -0.4480970800, 0.2679379284], - ], - [ - [-0.6253433824, -0.5198931098, -0.8529881239], - [-0.5198931098, -0.8529881239, -0.8529881239], - [-0.6253433824, -0.5198931098, -0.6253433824], - ], + [2.0, 2.0, 1.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, -1.0], + [1.0, 1.0, -1.0], + [1.0, 1.0, 2.0], ], [ - [ - [-1.6703201532, -1.1265823841, -0.1376807690], - [-1.1265823841, 0.7218121886, -1.1265823841], - [-0.1376807690, -0.1376807690, -0.1376807690], - ], - [ - [0.8635767698, -0.5252661109, 0.3244057596], - [-0.5252661109, -0.8940765262, -0.5252661109], - [0.3244057596, 0.3244057596, 0.3244057596], - ], - [ - [-1.5300362110, -0.9950634241, -1.0523844957], - [-0.9950634241, 1.7175949812, -0.9950634241], - [-1.0523844957, -1.0523844957, -1.0523844957], - ], - [ - [0.5375117064, 0.2679379284, -0.4480970800], - [0.2679379284, -0.3190571964, 0.2679379284], - [-0.4480970800, -0.4480970800, -0.4480970800], - ], - [ - [-1.6273639202, -0.6253433824, -0.5198931098], - [-0.6253433824, -0.8529881239, -0.6253433824], - [-0.5198931098, -0.5198931098, -0.5198931098], - ], + [2.0, 2.0, -1.0], + [1.0, -2.0, -2.0], + [0.0, 0.0, 0.0], + [1.0, 0.0, 2.0], + [1.0, 1.0, 0.0], ], ] ) @@ -105,8 +78,47 @@ def test_quantised_shape(self): vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) output = vq(self.encoded) actual = torch.tensor(output.shape) - expected = torch.tensor([2, 5, 3, 3]) + expected = torch.tensor([2, 5, 3]) assert torch.equal( actual, expected ), f"actual shape: {actual}, expected shape: {expected}" + + def test_preprocess(self): + vq = Quantisation( + num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim + ) + encoded_flat, permuted_shape = vq._preprocess(self.encoded) + + expected_flat_shape = torch.tensor([6, 5]) + expected_permuted_shape = torch.tensor([2, 3, 5]) + + actual_flat_shape = torch.tensor(encoded_flat.shape) + actual_permuted_shape = torch.tensor(permuted_shape) + + assert torch.equal( + actual_flat_shape, expected_flat_shape + ), f"actual flattened shape: {actual_flat_shape}, expected flattened shape: {expected_flat_shape}" + + assert torch.equal( + actual_permuted_shape, expected_permuted_shape + ), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}" + + def test_preprocess_channel_dim_assertion(self): + vq = Quantisation( + num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim + ) + with self.assertRaises(Exception): + encoded_flat, permuted_shape = vq._preprocess(self.encoded[:, :4, :]) + + def test_postprocess(self): + vq = Quantisation( + num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim + ) + quantised = vq._postprocess(self.test_tensor_flat, torch.Size([2, 2, 3])) + actual_quantised_shape = torch.tensor(quantised.shape) + expected_quantised_shape = torch.tensor([2, 3, 2]) + + assert torch.equal( + actual_quantised_shape, expected_quantised_shape + ), f"actual quantised shape: {actual_quantised_shape}, expected quantised shape: {expected_quantised_shape}" diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index 220276d8..ffbbbfff 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -5,19 +5,22 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torch import nn, Tensor +from torch import nn, Tensor, Size class Quantisation(nn.Module): - """Quantisation provides an embedding layer that takes in a collection of flattened vectors, usually the - output of an encoder architecture, and performs a nearest-neighbor lookup in the embedding space. + """Quantisation provides an embedding layer that takes in the output of an encoder + and performs a nearest-neighbor lookup in the embedding space. - Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf) to generate high-fidelity - images, videos, and audio data. + Vector quantisation was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf) + to generate high-fidelity images, videos, and audio data. Args: num_embeddings (int): the number of vectors in the embedding space embedding_dim (int): the dimensionality of the embedding vectors + + Inputs: + x (Tensor): Tensor containing a batch of encoder outputs. Expects dimensions to be batch x channel x n dims. """ def __init__(self, num_embeddings: int, embedding_dim: int): @@ -31,7 +34,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int): ) self.quantised_vectors = None - def forward(self, x: Tensor): + def _preprocess(self, x: Tensor): # Rearrange from batch x channel x n dims to batch x n dims x channel new_dims = (0,) + tuple(range(2, len(x.shape))) + (1,) x_permuted = x.permute(new_dims).contiguous() @@ -45,6 +48,18 @@ def forward(self, x: Tensor): x_flat.shape[-1] == self.embedding_dim ), f"Expected {x_flat.shape[-1]} to be embedding size of {self.embedding_dim}" + return x_flat, permuted_shape + + def _postprocess(self, quantised_flat: Tensor, permuted_shape: Size): + # Rearrange back to batch x channel x n dims + num_dims = len(permuted_shape) + quantised_permuted = quantised_flat.view(permuted_shape) + old_dims = (0,) + (num_dims - 1,) + tuple(range(1, num_dims - 1)) + quantised = quantised_permuted.permute(old_dims).contiguous() + + return quantised + + def quantise(self, x_flat: Tensor): # Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2 distances = torch.cdist(x_flat, self.embedding.weight, p=2.0) ** 2 @@ -52,14 +67,22 @@ def forward(self, x: Tensor): encoding_indices = torch.argmin(distances, dim=1) # Quantise - quantised_permuted = self.embedding(encoding_indices).view(permuted_shape) + quantised_flat = self.embedding(encoding_indices) # Straight through estimator - quantised_permuted = x_permuted + (quantised_permuted - x_permuted).detach() + quantised_flat = x_flat + (quantised_flat - x_flat).detach() - # Rearrange back to batch x channel x n dims - old_dims = (0,) + (len(x.shape) - 1,) + tuple(range(1, len(x.shape) - 1)) - quantised = quantised_permuted.permute(old_dims).contiguous() + return quantised_flat + + def forward(self, x: Tensor): + # Reshape and flatten encoder output for quantisation + x_flat, permuted_shape = self._preprocess(x) + + # Quantisation via nearest neighbor lookup + quantised_flat = self.quantise(x_flat) + + # Reshape back to original dims + quantised = self._postprocess(quantised_flat, permuted_shape) self.quantised_vectors = quantised return quantised From 14587db33b422dad81f321dd8037ecaa9af710d7 Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Thu, 28 Apr 2022 13:30:43 -0700 Subject: [PATCH 6/7] improved structure of unit tests --- test/modules/layers/test_quantisation.py | 42 ++++--------------- test/modules/losses/test_commitment.py | 13 +++--- .../modules/layers/quantisation.py 
| 3 +- 3 files changed, 16 insertions(+), 42 deletions(-) diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index d1f0259c..4e44331e 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -7,7 +7,6 @@ import unittest import torch -from test.test_utils import set_rng_seed from torch import nn from torchmultimodal.modules.layers.quantisation import Quantisation @@ -18,10 +17,9 @@ class TestQuantisation(unittest.TestCase): """ def setUp(self): - torch.set_printoptions(precision=10) - set_rng_seed(4) self.num_embeddings = 4 self.embedding_dim = 5 + # This is 2x5x3 self.encoded = torch.Tensor( [ @@ -34,17 +32,17 @@ def setUp(self): [[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]] ) # This is 4x3 - self.test_tensor_flat = torch.Tensor( + self.input_tensor_flat = torch.Tensor( [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] ) - def test_quantised_output(self): - vq = Quantisation( + self.vq = Quantisation( num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim ) - vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) - actual = vq(self.encoded) + self.vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) + def test_quantised_output(self): + actual = self.vq(self.encoded) # This is shape (2,5,3) expected = torch.Tensor( [ @@ -71,24 +69,8 @@ def test_quantised_output(self): msg=f"actual: {actual}, expected: {expected}", ) - def test_quantised_shape(self): - vq = Quantisation( - num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim - ) - vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights) - output = vq(self.encoded) - actual = torch.tensor(output.shape) - expected = torch.tensor([2, 5, 3]) - - assert torch.equal( - actual, expected - ), f"actual shape: {actual}, expected shape: {expected}" - def test_preprocess(self): - vq = Quantisation( - num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim - ) - encoded_flat, permuted_shape = vq._preprocess(self.encoded) + encoded_flat, permuted_shape = self.vq._preprocess(self.encoded) expected_flat_shape = torch.tensor([6, 5]) expected_permuted_shape = torch.tensor([2, 3, 5]) @@ -105,17 +87,11 @@ def test_preprocess(self): ), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}" def test_preprocess_channel_dim_assertion(self): - vq = Quantisation( - num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim - ) with self.assertRaises(Exception): - encoded_flat, permuted_shape = vq._preprocess(self.encoded[:, :4, :]) + encoded_flat, permuted_shape = self.vq._preprocess(self.encoded[:, :4, :]) def test_postprocess(self): - vq = Quantisation( - num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim - ) - quantised = vq._postprocess(self.test_tensor_flat, torch.Size([2, 2, 3])) + quantised = self.vq._postprocess(self.input_tensor_flat, torch.Size([2, 2, 3])) actual_quantised_shape = torch.tensor(quantised.shape) expected_quantised_shape = torch.tensor([2, 3, 2]) diff --git a/test/modules/losses/test_commitment.py b/test/modules/losses/test_commitment.py index 44248eaa..5c998de9 100644 --- a/test/modules/losses/test_commitment.py +++ b/test/modules/losses/test_commitment.py @@ -7,7 +7,6 @@ import unittest import torch -from test.test_utils import set_rng_seed from torchmultimodal.modules.losses.vqvae import CommitmentLoss @@ -17,17 +16,15 @@ class TestCommitment(unittest.TestCase): """ def 
setUp(self): - torch.set_printoptions(precision=10) - set_rng_seed(4) - self.quantized = torch.randn((2, 3)) - self.encoded = torch.randn((2, 3)) + self.quantized = torch.Tensor([[-1, 0, 1], [2, 1, 0]]) + self.encoded = torch.Tensor([[-2, -1, 0], [0, 2, -2]]) + self.commitment = CommitmentLoss() def test_loss_value(self): - commitment = CommitmentLoss() - loss = commitment(self.quantized, self.encoded) + loss = self.commitment(self.quantized, self.encoded) actual = loss.item() - expected = 1.2070025206 + expected = 2.0 torch.testing.assert_close( actual, diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index ffbbbfff..ee169def 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -20,7 +20,8 @@ class Quantisation(nn.Module): embedding_dim (int): the dimensionality of the embedding vectors Inputs: - x (Tensor): Tensor containing a batch of encoder outputs. Expects dimensions to be batch x channel x n dims. + x (Tensor): Tensor containing a batch of encoder outputs. + Expects dimensions to be batch x channel x n dims. """ def __init__(self, num_embeddings: int, embedding_dim: int): From 00b2049c1eb9a5e120a0a86fe34f9dde900795fc Mon Sep 17 00:00:00 2001 From: Rafi Ayub Date: Tue, 3 May 2022 10:14:02 -0700 Subject: [PATCH 7/7] fixing lint issue with assertion test --- test/modules/layers/test_quantisation.py | 2 +- torchmultimodal/modules/layers/quantisation.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/modules/layers/test_quantisation.py b/test/modules/layers/test_quantisation.py index 4e44331e..2a524191 100644 --- a/test/modules/layers/test_quantisation.py +++ b/test/modules/layers/test_quantisation.py @@ -87,7 +87,7 @@ def test_preprocess(self): ), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}" def test_preprocess_channel_dim_assertion(self): - with self.assertRaises(Exception): + with self.assertRaises(ValueError): encoded_flat, permuted_shape = self.vq._preprocess(self.encoded[:, :4, :]) def test_postprocess(self): diff --git a/torchmultimodal/modules/layers/quantisation.py b/torchmultimodal/modules/layers/quantisation.py index ee169def..43b744ce 100644 --- a/torchmultimodal/modules/layers/quantisation.py +++ b/torchmultimodal/modules/layers/quantisation.py @@ -45,9 +45,10 @@ def _preprocess(self, x: Tensor): x_flat = x_permuted.view(-1, permuted_shape[-1]) # channel dimension should be embedding dim so that each element in encoder # output volume gets associated with single embedding vector - assert ( - x_flat.shape[-1] == self.embedding_dim - ), f"Expected {x_flat.shape[-1]} to be embedding size of {self.embedding_dim}" + if x_flat.shape[-1] != self.embedding_dim: + raise ValueError( + f"Expected {x_flat.shape[-1]} to be embedding size of {self.embedding_dim}" + ) return x_flat, permuted_shape
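
Usage sketch: a minimal example of how the Quantisation layer and CommitmentLoss added by this series could be combined in a VQ-VAE-style training step. Only the import paths and module signatures come from the patches above; the Conv2d encoder, the hyperparameter values, and the tensor shapes are illustrative assumptions.

    import torch
    from torch import nn
    from torchmultimodal.modules.layers.quantisation import Quantisation
    from torchmultimodal.modules.losses.vqvae import CommitmentLoss

    # Hypothetical encoder; any module whose output channel dim equals embedding_dim works.
    encoder = nn.Conv2d(3, 5, kernel_size=3, padding=1)
    quantiser = Quantisation(num_embeddings=4, embedding_dim=5)
    commitment = CommitmentLoss(commitment_cost=0.25)

    images = torch.randn(2, 3, 8, 8)       # assumed toy batch
    encoded = encoder(images)               # batch x channel x H x W, channel == embedding_dim
    quantised = quantiser(encoded)          # same shape; straight-through estimator preserves encoder gradients
    loss = commitment(quantised, encoded)   # scalar commitment term (the beta term of Oord et al. 2017, Eq. 3)
    loss.backward()

Note that this series only adds the commitment term; a full VQ-VAE objective would also include a reconstruction loss and a codebook (or EMA) update, which these patches do not cover.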