Transfusion [Paper]
- Transfusion is a multi-modal transformer: it can generate text like GPTs and images like diffusion models, all at once in a single model rather than separately!
- It can easily switch between text and image modalities during generation, and it is nothing complicated: just a single transformer with some modality-specific components!
- In principle, this can be extended to other modalities like video, audio, etc., but for now it only takes images and text as input
- For now, I have "test-trained" it on:
- Fashion MNIST Dataset (contains images of Fashion Items like T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
- MNIST Dataset (contains images of Digits from 0 to 9)
- I have used the class labels as the text and trained it on them. See below for some generated examples...
TODO: Train on a large multi-modal dataset (something like the TinyStories dataset with images in between illustrating the story...?)
- Test-trained on the Fashion MNIST dataset <===> Training notebook with some generated samples
- Test-trained on the MNIST dataset <===> Training notebook with some generated samples
- Important Snippets from the Paper
- Can produce 2 images of fashion items along with the text (in the form of tokens) shown above each image. The integers above the images can be interpreted using this dictionary:
{'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4, 'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}
So 5 means it's a sandal and 0 means it's a T-shirt/top in the image below, and likewise for the other examples. Use the dictionary to interpret the tokens as text (for now; this will change).
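A minimal sketch of reading generated tokens back as text by inverting the dictionary above. `decode_tokens` is a hypothetical helper for illustration, not a function from this repo.

```python
# Label dictionary copied from above: class name -> token id
LABEL_TO_ID = {'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4,
               'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}

# Invert it so generated integer tokens can be read back as class names
ID_TO_LABEL = {idx: name for name, idx in LABEL_TO_ID.items()}


def decode_tokens(tokens):
    """Map a list of generated class tokens to their Fashion MNIST label names."""
    return [ID_TO_LABEL[t] for t in tokens]


print(decode_tokens([5, 0]))  # ['Sandal', 'T-shirt/top']
```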
- See this notebook for more examples.
- Generates text and images in an alternating way, as shown below
- See this notebook for more examples
- Transfusion pretrains a transformer model on 50% text and 50% image data, using a different objective for each modality: next-token prediction for text and diffusion for images
- We apply causal attention for text tokens and bidirectional attention for image patches. For inference, we introduce a decoding algorithm that combines the standard practices of text generation from language models and image generation from diffusion models
- Intra-image bidirectional attention is important, and replacing it with causal attention hurts text-to-image generation
- Language Modeling: autoregressive classification (next-token prediction)
- Loss: the usual cross-entropy loss
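A minimal sketch of the text objective, in plain Python on lists of logits rather than tensors. `next_token_loss` is a hypothetical helper; it assumes position i predicts token i + 1 and averages over positions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a single vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def next_token_loss(logits_per_position, tokens):
    """Mean cross-entropy where position i predicts token i + 1.

    logits_per_position[i] holds the model's scores over the vocabulary after
    seeing tokens[: i + 1]; the last position has no target and is skipped.
    """
    losses = []
    for i in range(len(tokens) - 1):
        log_probs = log_softmax(logits_per_position[i])
        losses.append(-log_probs[tokens[i + 1]])
    return sum(losses) / len(losses)
```

With uniform logits over a 10-token vocabulary, every target gets probability 1/10, so the loss is log(10), roughly 2.3026.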
- Noise Schedule: cosine scheduler (We found that while the linear noise schedule used in Ho et al. (2020) worked well for high-resolution images, it was sub-optimal for images of resolution 64 × 64 and 32 × 32)
- Loss: Mean Squared Error
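A minimal sketch of the image objective: the cosine noise schedule from Nichol & Dhariwal (2021), the forward noising step, and an MSE loss. The offset s = 0.008 and the 0.999 beta clip are that paper's defaults, assumed here; the sketch also assumes the model predicts the added noise ε, as in the DDPM objective. Pure Python lists stand in for latent tensors.

```python
import math
import random


def cosine_alpha_bar(t, T, s=0.008):
    """Cumulative signal level alpha_bar(t) of the cosine schedule."""
    f = lambda u: math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0)


def cosine_betas(T, max_beta=0.999):
    """Per-step variances beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped for stability."""
    return [min(1 - cosine_alpha_bar(t, T) / cosine_alpha_bar(t - 1, T), max_beta)
            for t in range(1, T + 1)]


def add_noise(x0, t, T, rng=random):
    """Sample x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps; returns (x_t, eps)."""
    a = cosine_alpha_bar(t, T)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


def mse(pred, target):
    """Mean squared error between the predicted and the true noise."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)
```

During training, the model's noise prediction for x_t would be compared against eps with mse.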
- Latent Image Representation: Variational autoencoders (VAEs) [Kingma and Welling, 2013] can save compute by encoding images into a lower-dimensional latent space
- Discrete text and continuous images
- Each text string is tokenized into a sequence of discrete tokens from a fixed vocabulary, where each token is represented as an integer
- The vast majority of the model’s parameters belong to a single transformer, which processes every sequence, regardless of modality (We follow Llama’s [Touvron et al., 2023a] flavour of the transformer block, which includes the SwiGLU activation function [Shazeer, 2020] and RoPE [Su et al., 2024])
- To convert our data into this space, we use lightweight modality-specific components with unshared parameters
- For text, these are the embedding matrices
- For images, we experiment with two alternatives for compressing local windows of k × k patch vectors into a single transformer vector (and vice versa):
- a simple linear layer (We add an embedding of the timestep t to every patch vector before the linear layer)
- up and down blocks of a U-Net (We replace the U-Net’s AdaLayerNorm with regular layer norm in our implementation)
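A minimal sketch of the patchification step that precedes the linear layer (or U-Net down block): a latent of shape C × H × W is cut into k × k windows and each window is flattened into one vector. The channel-first layout and row-major patch order are assumptions; the learned projection and timestep embedding are left out.

```python
def patchify(latent, k):
    """Split a C x H x W latent (nested lists) into row-major k x k patches.

    Each patch is flattened into one vector of length C * k * k, which a
    linear layer would then map to the transformer's hidden size.
    """
    C, H, W = len(latent), len(latent[0]), len(latent[0][0])
    assert H % k == 0 and W % k == 0, "H and W must be divisible by k"
    patches = []
    for i in range(0, H, k):
        for j in range(0, W, k):
            patches.append([latent[c][i + di][j + dj]
                            for c in range(C) for di in range(k) for dj in range(k)])
    return patches


def unpatchify(patches, C, H, W, k):
    """Inverse of patchify: scatter patch vectors back into a C x H x W latent."""
    latent = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    per_row = W // k
    for p, vec in enumerate(patches):
        i, j = (p // per_row) * k, (p % per_row) * k
        idx = 0
        for c in range(C):
            for di in range(k):
                for dj in range(k):
                    latent[c][i + di][j + dj] = vec[idx]
                    idx += 1
    return latent
```

For a 2 × 4 × 4 latent and k = 2, this yields 4 patch vectors of length 8, and unpatchify recovers the original latent.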
- Transfusion Attention: While text is naturally sequential, images are not, and are usually modelled with unrestricted (bidirectional) attention. Transfusion combines both attention patterns by applying causal attention to every element in the sequence, and bidirectional attention within the elements of each individual image
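A minimal sketch of the combined attention mask, where True means "query may attend to key". The encoding of the sequence (None for a text token, including BOI/EOI markers, and an integer image id for each image patch) is a hypothetical choice for illustration.

```python
def transfusion_mask(modalities):
    """Build a boolean attention mask (mask[q][k] is True if q may attend to k).

    Every position attends causally; patches that belong to the same image
    additionally attend to each other in both directions.
    """
    n = len(modalities)
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            causal = k <= q
            same_image = modalities[q] is not None and modalities[q] == modalities[k]
            mask[q][k] = causal or same_image
    return mask


# Two text tokens, three patches of image 0, then one more text token
m = transfusion_mask([None, None, 0, 0, 0, None])
print(m[2])  # [True, True, True, True, True, False]
```

The first patch of the image can see the later patches of the same image, but text tokens still cannot see anything after them.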
- LM loss is computed per token (When the input is a BOI token, we do not compute any loss), while diffusion loss is computed per image, which may span multiple elements (image patches) in the sequence
- Specifically, we add noise ϵ to each input latent image x0 according to the diffusion process to produce xt before patchification, and then compute the image-level diffusion loss
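A minimal sketch of how the two losses are combined: LM loss per token (skipping positions whose input is BOI), diffusion loss per image, and L = L_LM + λ · L_DDPM, with λ defaulting to the balancing_coeff value of 5 listed in this README. Averaging each term is an assumption, and the "<BOI>" string is a hypothetical stand-in for the real token id.

```python
BOI = "<BOI>"  # hypothetical begin-of-image marker used only in this sketch


def transfusion_loss(token_losses, input_tokens, image_losses, lam=5.0):
    """Combine the two objectives as L = L_LM + lam * L_DDPM.

    token_losses[i] is the cross-entropy for predicting the token after
    input_tokens[i]; positions whose input is BOI contribute no LM loss.
    image_losses holds one diffusion (MSE) loss per image, however many
    patches that image spans in the sequence.
    """
    lm_terms = [l for l, tok in zip(token_losses, input_tokens) if tok != BOI]
    lm_loss = sum(lm_terms) / len(lm_terms) if lm_terms else 0.0
    ddpm_loss = sum(image_losses) / len(image_losses) if image_losses else 0.0
    return lm_loss + lam * ddpm_loss
```

For example, token losses [1.0, 3.0, 2.0] on inputs ["a", "b", BOI] give L_LM = 2.0; image losses [0.1, 0.3] give L_DDPM = 0.2, so the total is 2.0 + 5 × 0.2 = 3.0.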
- AdamW => | betas=(0.9, 0.95) | eps=1e-8 | lr=3e-4 | warmup=4000 | min_lr=1.5e-5 | weight_decay=0.1 | clip_norm=1.0 |
- balancing_coeff (λ in the loss L = L_LM + λ · L_DDPM) = 5
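A minimal sketch of a learning-rate schedule built from the values above (lr=3e-4, warmup=4000, min_lr=1.5e-5). The linear warmup, the cosine decay shape, and the total_steps argument are assumptions for illustration.

```python
import math


def learning_rate(step, total_steps, peak_lr=3e-4, min_lr=1.5e-5, warmup=4000):
    """Linear warmup to peak_lr, then cosine decay down to min_lr at total_steps.

    peak_lr, min_lr and warmup come from the hyperparameters above; the
    decay shape and total_steps are assumptions of this sketch.
    """
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = min((step - warmup) / max(1, total_steps - warmup), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

In PyTorch, the remaining values map onto torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1) and torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) before each optimizer step, with each param group's "lr" set from learning_rate(step, total_steps).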