
Add FSQ implementation #74

Merged · 14 commits · Sep 29, 2023
45 changes: 45 additions & 0 deletions README.md
@@ -251,6 +251,40 @@ indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize the codebooks by setting `sync_codebook = True | False`
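
A minimal sketch of that override, assuming `sync_codebook` is passed as a constructor keyword alongside the usual `dim` and `codebook_size` arguments:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False   # or True, to force codebook synchronization across processes
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```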

### Finite Scalar Quantization

<img src="./fsq.png" width="500px"></img>

| | VQ | FSQ |
|------------------|----|-----|
| Quantization | argmin_c \|\| z-c \|\| | round(f(z)) |
| Gradients | Straight Through Estimation (STE) | STE |
| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
| Tricks | EMA on codebook, codebook splitting, projections, ...| N/A |
| Parameters | Codebook | N/A |

[This](https://arxiv.org/abs/2309.15505) work out of Google DeepMind aims to vastly simplify how vector quantization is done for generative modeling: it removes the need for commitment losses and EMA updating of the codebook, and it sidesteps codebook collapse and insufficient codebook utilization. Each scalar is simply rounded to one of a fixed set of discrete levels with straight-through gradients; the resulting codes are uniform points in a hypercube.
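
The mechanism is small enough to sketch directly. For an odd number of levels it amounts to a `tanh` bound, a straight-through `round`, and a renormalization into `[-1, 1]`; even level counts additionally need a half-step offset, which the full implementation added in this PR handles. A minimal illustration (not the library API):

```python
import torch

def round_ste(z):
    # round in the forward pass, pass gradients straight through in the backward pass
    return z + (z.round() - z).detach()

levels = torch.tensor([5, 5, 5])           # levels per scalar dimension (odd, for simplicity)
half_width = levels // 2                   # bounded values land in [-2, 2] here

z = torch.randn(3, requires_grad = True)   # one latent vector with dim == len(levels)
code = round_ste(z.tanh() * half_width) / half_width   # each entry snaps to one of 5 uniform points in [-1, 1]

code.sum().backward()
print(code, z.grad)   # the gradient is the tanh derivative - the rounding is invisible to it
```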


```python
import torch
from vector_quantize_pytorch import FSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)

x = torch.randn(1, 1024, quantizer.dim)
xhat, indices = quantizer(x)

print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim)
print(indices.shape) # (1, 1024) - (batch, seq)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
assert torch.all(xhat == quantizer.implicit_codebook[indices])
```



## Todo

- [x] allow for multi-headed codebooks
@@ -384,3 +418,14 @@
primaryClass = {cs.CV}
}
```

```bibtex
@misc{mentzer2023finite,
title = {Finite Scalar Quantization: VQ-VAE Made Simple},
author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
year = {2023},
eprint = {2309.15505},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
2 changes: 1 addition & 1 deletion examples/autoencoder.py
@@ -41,7 +41,7 @@ def __init__(self, **vq_kwargs):
def forward(self, x):
for layer in self.layers:
if isinstance(layer, VectorQuantize):
-                x_flat, indices, commit_loss = layer(x)
+                x, indices, commit_loss = layer(x)
else:
x = layer(x)

94 changes: 94 additions & 0 deletions examples/autoencoder_fsq.py
@@ -0,0 +1,94 @@
# FashionMnist VQ experiment with various settings, using FSQ.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import FSQ


lr = 3e-4
train_iter = 1000
levels = [8, 6, 5] # target size 2^8, actual size 240
num_codes = math.prod(levels)
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleFSQAutoEncoder(nn.Module):
def __init__(self, levels: list[int]):
super().__init__()
self.layers = nn.ModuleList(
[
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.GELU(),
nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
nn.Conv2d(8, 8, kernel_size=6, stride=3, padding=0),
FSQ(levels),
nn.ConvTranspose2d(8, 8, kernel_size=6, stride=3, padding=0),
nn.Conv2d(8, 16, kernel_size=4, stride=1, padding=2),
nn.GELU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=2),
]
)
return

def forward(self, x):
for layer in self.layers:
if isinstance(layer, FSQ):
x, indices = layer(x)
else:
x = layer(x)

return x.clamp(-1, 1), indices


def train(model, train_loader, train_iterations=1000):
def iterate_dataset(data_loader):
data_iter = iter(data_loader)
while True:
try:
x, y = next(data_iter)
except StopIteration:
data_iter = iter(data_loader)
x, y = next(data_iter)
yield x.to(device), y.to(device)

for _ in (pbar := trange(train_iterations)):
opt.zero_grad()
x, _ = next(iterate_dataset(train_loader))
out, indices = model(x)
rec_loss = (out - x).abs().mean()
rec_loss.backward()

opt.step()
pbar.set_description(
f"rec loss: {rec_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)
return


transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
datasets.FashionMNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
Binary file added fsq.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
-  version = '1.7.1',
+  version = '1.8.0',
license='MIT',
description = 'Vector Quantization - Pytorch',
long_description_content_type = 'text/markdown',
3 changes: 2 additions & 1 deletion vector_quantize_pytorch/__init__.py
@@ -1,3 +1,4 @@
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
-from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
+from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
+from vector_quantize_pytorch.finite_scalar_quantization import FSQ
66 changes: 66 additions & 0 deletions vector_quantize_pytorch/finite_scalar_quantization.py
@@ -0,0 +1,66 @@
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""

import torch
import torch.nn as nn


def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()


class FSQ(nn.Module):

Review comments on this line:

> Thanks so much for porting this! Do you mind if we link this repo in the next version and our own public code release?

> LMK if you are also planning to update the README and I can send some figs.

Contributor: please do! and i believe @lucidrains and @sekstini will appreciate it

Contributor Author: Yup, go ahead 👍

Contributor Author: @fab-jul

> LMK if you are also planning to update the README and I can send some figs.

That would be great 🙏

def __init__(self, levels: list[int]):
super().__init__()
_levels = torch.tensor(levels, dtype=torch.int32)
self.register_buffer("_levels", _levels)

_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
self.register_buffer("_basis", _basis)

self.dim = len(levels)
self.n_codes = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
self.register_buffer("implicit_codebook", implicit_codebook)

def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
zhat = self.quantize(z)
indices = self.codes_to_indices(zhat)
return zhat, indices

def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 - eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).tan()
return (z + shift).tanh() * half_l - offset

def quantize(self, z: torch.Tensor) -> torch.Tensor:
"""Quanitzes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width

def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width

def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width

def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(torch.int32)

def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
"""Inverse of `codes_to_indices`."""
indices = indices.unsqueeze(-1)
codes_non_centered = (indices // self._basis) % self._levels
return self._scale_and_shift_inverse(codes_non_centered)
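
A quick sanity check of the index round trip, using the public interface shown in the README above (a sketch, not part of this PR's diff):

```python
import torch
from vector_quantize_pytorch import FSQ

quantizer = FSQ(levels = [3, 5, 4])       # 3 * 5 * 4 = 60 implicit codes

z = torch.randn(2, 7, quantizer.dim)      # (batch, seq, dim), with dim == len(levels)
zhat, indices = quantizer(z)

# indices are a mixed-radix encoding over the per-dimension levels,
# so decoding them must reproduce the quantized codes exactly
assert torch.all(quantizer.indices_to_codes(indices) == zhat)
assert indices.min() >= 0 and indices.max() < quantizer.n_codes
```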