Add FSQ implementation #74
Merged
lucidrains merged 14 commits into lucidrains:fsq from sekstini:finite-scalar-quantization on Sep 29, 2023
Commits
6b1ca27 Add FSQ implementation (sekstini)
d373463 Make it importable (sekstini)
79aea46 Register buffers and add forward method (sekstini)
01d9a1f Add FSQ Autoencoder example (sekstini)
1245c29 Remove print (sekstini)
e7549b2 That one too... (sekstini)
a832fc9 Add offset to even instead of odd (sekstini)
c989d30 Return zhat in forward (sekstini)
d403248 Fix gradient flow and reduce param count on FSQ model (sekstini)
0e66f53 Add dim and n_codes properties (sekstini)
1f68476 Bump version to 1.8.0 (sekstini)
64ce9e9 Add FSQ entry to README (sekstini)
bb9f7a9 Add FSQ png (sekstini)
e02ebab Add FSQ description and table (sekstini)
@@ -0,0 +1,94 @@
# FashionMnist VQ experiment with various settings, using FSQ.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange

import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import FSQ


lr = 3e-4
train_iter = 1000
levels = [8, 6, 5]  # target size 2^8, actual size 240
num_codes = math.prod(levels)
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"


class SimpleFSQAutoEncoder(nn.Module):
    def __init__(self, levels: list[int]):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(8, 8, kernel_size=6, stride=3, padding=0),
                FSQ(levels),
                nn.ConvTranspose2d(8, 8, kernel_size=6, stride=3, padding=0),
                nn.Conv2d(8, 16, kernel_size=4, stride=1, padding=2),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=2),
            ]
        )
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, FSQ):
                x, indices = layer(x)
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices


def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices = model(x)
        rec_loss = (out - x).abs().mean()
        rec_loss.backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
    return


transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
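As a quick orientation for the example above: the FSQ layer added in this PR behaves like any other nn.Module, its forward returns a pair (quantized tensor, integer indices), and quantization acts over the last dimension, which must equal len(levels). A minimal standalone sketch; the shapes and variable names below are illustrative only and not taken from the PR:

import torch
from vector_quantize_pytorch import FSQ

fsq = FSQ(levels=[8, 6, 5])   # implicit codebook of 8 * 6 * 5 = 240 codes
z = torch.randn(4, 3)         # last dim must equal len(levels) == 3
zhat, indices = fsq(z)        # zhat: same shape as z, values snapped to a fixed grid in [-1, 1]
assert zhat.shape == z.shape
assert indices.shape == (4,)  # one codebook index per 3-dimensional latent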
(Non-renderable binary file in this diff, likely the FSQ image added in commit bb9f7a9.)

@@ -1,3 +1,4 @@
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
@@ -0,0 +1,66 @@
""" | ||
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 | ||
Code adapted from Jax version in Appendix A.1 | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def round_ste(z: torch.Tensor) -> torch.Tensor: | ||
"""Round with straight through gradients.""" | ||
zhat = z.round() | ||
return z + (zhat - z).detach() | ||
|
||
|
||
class FSQ(nn.Module): | ||
def __init__(self, levels: list[int]): | ||
super().__init__() | ||
_levels = torch.tensor(levels, dtype=torch.int32) | ||
self.register_buffer("_levels", _levels) | ||
|
||
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) | ||
self.register_buffer("_basis", _basis) | ||
|
||
self.dim = len(levels) | ||
self.n_codes = self._levels.prod().item() | ||
implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes)) | ||
self.register_buffer("implicit_codebook", implicit_codebook) | ||
|
||
def forward(self, z: torch.Tensor) -> torch.Tensor: | ||
zhat = self.quantize(z) | ||
indices = self.codes_to_indices(zhat) | ||
return zhat, indices | ||
|
||
def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: | ||
"""Bound `z`, an array of shape (..., d).""" | ||
half_l = (self._levels - 1) * (1 - eps) / 2 | ||
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) | ||
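        # Explanatory note (not part of the original diff): when a level count is
        # even there is no integer at the center of its range, so a 0.5 offset
        # shifts the bounded output by half a step before rounding, and the
        # matching tan() pre-shift keeps z = 0 mapping to (approximately) 0;
        # odd level counts keep the symmetric integer grid and need no offset.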
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Inverse of `codes_to_indices`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)
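To make the implicit codebook concrete, here is a small hand-checkable sketch using the module above; the tiny levels=[3, 2] setting is chosen purely for illustration and does not appear anywhere in this PR:

import torch
from vector_quantize_pytorch import FSQ

fsq = FSQ(levels=[3, 2])   # 3 * 2 = 6 implicit codes
print(fsq.implicit_codebook)
# tensor([[-1., -1.],
#         [ 0., -1.],
#         [ 1., -1.],
#         [-1.,  0.],
#         [ 0.,  0.],
#         [ 1.,  0.]])

# codes_to_indices and indices_to_codes are exact inverses on the codebook
idx = torch.arange(fsq.n_codes)
codes = fsq.indices_to_codes(idx)
assert torch.equal(fsq.codes_to_indices(codes), idx.to(torch.int32))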
Conversation
Thanks so much for porting this! Do you mind if we link this repo in the next version and our own public code release?
LMK if you are also planning to update the README and I can send some figs.
please do! and i believe @lucidrains and @sekstini will appreciate it
Yup, go ahead 👍
@fab-jul
That would be great 🙏