diff --git a/README.md b/README.md
index 9e9b9a1..a93fa8f 100644
--- a/README.md
+++ b/README.md
@@ -251,6 +251,40 @@ indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)
 
 This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
 
+### Finite Scalar Quantization
+
+<img src="./fsq.png" width="500px"></img>
+
+|                  | VQ | FSQ |
+|------------------|----|-----|
+| Quantization     | argmin_c \|\| z-c \|\| | round(f(z)) |
+| Gradients        | Straight Through Estimation (STE) | STE |
+| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
+| Tricks           | EMA on codebook, codebook splitting, projections, ... | N/A |
+| Parameters       | Codebook | N/A |
+
+[This](https://arxiv.org/abs/2309.15505) work out of Google DeepMind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses and EMA updating of the codebook, as well as tackling the issues of codebook collapse and insufficient utilization. Each scalar is simply rounded to one of a fixed number of discrete levels with straight-through gradients; the codes then become uniform points in a hypercube.
+
+```python
+import torch
+from vector_quantize_pytorch import FSQ
+
+levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
+quantizer = FSQ(levels)
+
+x = torch.randn(1, 1024, quantizer.dim)
+xhat, indices = quantizer(x)
+
+print(xhat.shape)    # (1, 1024, 4) - (batch, seq, dim)
+print(indices.shape) # (1, 1024)    - (batch, seq)
+
+assert torch.all(xhat == quantizer.indices_to_codes(indices))
+assert torch.all(xhat == quantizer.implicit_codebook[indices])
+```
+
 ## Todo
 
 - [x] allow for multi-headed codebooks
@@ -384,3 +418,14 @@ This repository should also automatically synchronize the codebooks in a multi
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{mentzer2023finite,
+    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
+    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
+    year    = {2023},
+    eprint  = {2309.15505},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
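To make the "implicit codebook" in the table above concrete: with `levels = [8, 5, 5, 5]`, every combination of per-dimension levels is a valid code, so the codebook has `8 * 5 * 5 * 5 = 1000` entries, lies on a uniform grid inside `[-1, 1]^4`, and carries no learned parameters. A small illustrative check against the `FSQ` module (a sketch for intuition, not an additional API):

```python
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]
quantizer = FSQ(levels)

# the codebook is implicit: every combination of per-dimension levels is a code,
# so its size is simply the product of the levels (8 * 5 * 5 * 5 = 1000)
print(quantizer.n_codes)                  # 1000
print(quantizer.implicit_codebook.shape)  # torch.Size([1000, 4])

# each coordinate of a code can only take `levels[i]` distinct values inside [-1, 1]
for i, level in enumerate(levels):
    distinct = quantizer.implicit_codebook[:, i].unique().numel()
    assert distinct == level              # 8, 5, 5, 5
```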
diff --git a/examples/autoencoder.py b/examples/autoencoder.py
index 0be3f53..676814a 100644
--- a/examples/autoencoder.py
+++ b/examples/autoencoder.py
@@ -41,7 +41,7 @@ def __init__(self, **vq_kwargs):
     def forward(self, x):
         for layer in self.layers:
             if isinstance(layer, VectorQuantize):
-                x_flat, indices, commit_loss = layer(x)
+                x, indices, commit_loss = layer(x)
             else:
                 x = layer(x)
 
diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py
new file mode 100644
index 0000000..2f3e9fc
--- /dev/null
+++ b/examples/autoencoder_fsq.py
@@ -0,0 +1,94 @@
+# FashionMnist VQ experiment with various settings, using FSQ.
+# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
+
+from tqdm.auto import trange
+
+import math
+import torch
+import torch.nn as nn
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+
+from vector_quantize_pytorch import FSQ
+
+
+lr = 3e-4
+train_iter = 1000
+levels = [8, 6, 5]  # target size 2^8, actual size 240
+num_codes = math.prod(levels)
+seed = 1234
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+class SimpleFSQAutoEncoder(nn.Module):
+    def __init__(self, levels: list[int]):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [
+                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
+                nn.MaxPool2d(kernel_size=2, stride=2),
+                nn.GELU(),
+                nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
+                nn.Conv2d(8, 8, kernel_size=6, stride=3, padding=0),
+                FSQ(levels),
+                nn.ConvTranspose2d(8, 8, kernel_size=6, stride=3, padding=0),
+                nn.Conv2d(8, 16, kernel_size=4, stride=1, padding=2),
+                nn.GELU(),
+                nn.Upsample(scale_factor=2, mode="nearest"),
+                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=2),
+            ]
+        )
+        return
+
+    def forward(self, x):
+        for layer in self.layers:
+            if isinstance(layer, FSQ):
+                x, indices = layer(x)
+            else:
+                x = layer(x)
+
+        return x.clamp(-1, 1), indices
+
+
+def train(model, train_loader, train_iterations=1000):
+    def iterate_dataset(data_loader):
+        data_iter = iter(data_loader)
+        while True:
+            try:
+                x, y = next(data_iter)
+            except StopIteration:
+                data_iter = iter(data_loader)
+                x, y = next(data_iter)
+            yield x.to(device), y.to(device)
+
+    for _ in (pbar := trange(train_iterations)):
+        opt.zero_grad()
+        x, _ = next(iterate_dataset(train_loader))
+        out, indices = model(x)
+        rec_loss = (out - x).abs().mean()
+        rec_loss.backward()
+
+        opt.step()
+        pbar.set_description(
+            f"rec loss: {rec_loss.item():.3f} | "
+            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
+        )
+    return
+
+
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
+train_dataset = DataLoader(
+    datasets.FashionMNIST(
+        root="~/data/fashion_mnist", train=True, download=True, transform=transform
+    ),
+    batch_size=256,
+    shuffle=True,
+)
+
+print("baseline")
+torch.random.manual_seed(seed)
+model = SimpleFSQAutoEncoder(levels).to(device)
+opt = torch.optim.AdamW(model.parameters(), lr=lr)
+train(model, train_dataset, train_iterations=train_iter)
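The `active %` readout in the progress bar above is a codebook-utilization measure: because the FSQ codebook is implicit, the total number of possible codes is just `math.prod(levels)` (240 for `[8, 6, 5]`), and the metric is the fraction of those indices that appear in the current batch. Pulled out as a standalone helper for clarity (the name `codebook_usage` is illustrative, not part of the library):

```python
import math
import torch

def codebook_usage(indices: torch.Tensor, levels: list[int]) -> float:
    """Fraction of the implicit FSQ codebook hit by a batch of indices.

    `indices` can have any shape; the number of possible codes is simply
    the product of the per-dimension levels.
    """
    num_codes = math.prod(levels)
    return indices.unique().numel() / num_codes

# e.g. levels = [8, 6, 5] gives 240 possible codes; the training loop above
# reports this same ratio as "active %"
fake_indices = torch.randint(0, 240, (256, 8, 3))
print(f"{codebook_usage(fake_indices, [8, 6, 5]) * 100:.1f}% of codes used")
```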
diff --git a/fsq.png b/fsq.png
new file mode 100644
index 0000000..0d30ad0
Binary files /dev/null and b/fsq.png differ
diff --git a/setup.py b/setup.py
index bb688a6..61e541f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.7.1',
+  version = '1.8.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vector_quantize_pytorch/__init__.py b/vector_quantize_pytorch/__init__.py
index 74e5ec6..becc5e0 100644
--- a/vector_quantize_pytorch/__init__.py
+++ b/vector_quantize_pytorch/__init__.py
@@ -1,3 +1,4 @@
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
 from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
-from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
\ No newline at end of file
+from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
+from vector_quantize_pytorch.finite_scalar_quantization import FSQ
\ No newline at end of file
diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py
new file mode 100644
index 0000000..e34493f
--- /dev/null
+++ b/vector_quantize_pytorch/finite_scalar_quantization.py
@@ -0,0 +1,66 @@
+"""
+Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
+Code adapted from Jax version in Appendix A.1
+"""
+
+import torch
+import torch.nn as nn
+
+
+def round_ste(z: torch.Tensor) -> torch.Tensor:
+    """Round with straight through gradients."""
+    zhat = z.round()
+    return z + (zhat - z).detach()
+
+
+class FSQ(nn.Module):
+    def __init__(self, levels: list[int]):
+        super().__init__()
+        _levels = torch.tensor(levels, dtype=torch.int32)
+        self.register_buffer("_levels", _levels)
+
+        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
+        self.register_buffer("_basis", _basis)
+
+        self.dim = len(levels)
+        self.n_codes = self._levels.prod().item()
+        implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
+        self.register_buffer("implicit_codebook", implicit_codebook)
+
+    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        zhat = self.quantize(z)
+        indices = self.codes_to_indices(zhat)
+        return zhat, indices
+
+    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
+        """Bound `z`, an array of shape (..., d)."""
+        half_l = (self._levels - 1) * (1 - eps) / 2
+        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
+        shift = (offset / half_l).tan()
+        return (z + shift).tanh() * half_l - offset
+
+    def quantize(self, z: torch.Tensor) -> torch.Tensor:
+        """Quantizes z, returns quantized zhat, same shape as z."""
+        quantized = round_ste(self.bound(z))
+        half_width = self._levels // 2  # Renormalize to [-1, 1].
+        return quantized / half_width
+
+    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
+        half_width = self._levels // 2
+        return (zhat_normalized * half_width) + half_width
+
+    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
+        half_width = self._levels // 2
+        return (zhat - half_width) / half_width
+
+    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
+        """Converts a `code` to an index in the codebook."""
+        assert zhat.shape[-1] == self.dim
+        zhat = self._scale_and_shift(zhat)
+        return (zhat * self._basis).sum(dim=-1).to(torch.int32)
+
+    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
+        """Inverse of `codes_to_indices`."""
+        indices = indices.unsqueeze(-1)
+        codes_non_centered = (indices // self._basis) % self._levels
+        return self._scale_and_shift_inverse(codes_non_centered)
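The `_basis` buffer above is what turns a code's per-dimension digits into a single integer index: it is the cumulative product of `[1] + levels[:-1]`, so `codes_to_indices` encodes a little-endian mixed-radix number and `indices_to_codes` inverts it with floor-division and modulo. A standalone illustration of that round trip (plain tensors and made-up digit values, independent of the module):

```python
import torch

levels = torch.tensor([8, 5, 5, 5])
basis = torch.cumprod(torch.cat([torch.ones(1, dtype=torch.long), levels[:-1]]), dim=0)
# basis == [1, 8, 40, 200]: a little-endian mixed-radix number system

digits = torch.tensor([3, 1, 4, 2])       # one "digit" in [0, levels[i]) per dimension
index = (digits * basis).sum()            # 3*1 + 1*8 + 4*40 + 2*200 = 571
recovered = (index // basis) % levels     # peel each digit back off
assert torch.equal(recovered, digits)
print(index.item(), recovered.tolist())   # 571 [3, 1, 4, 2]
```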