RodgersLuo/open-dalle-2

Open DALL-E 2: a simplified implementation of DALL-E 2 (unCLIP).

This is an implementation of OpenAI's DALL-E 2 [Link] [Paper] in PyTorch, suitable for simple text-to-image generation tasks.

Generated samples on the CIFAR-10 dataset: [image]

Generated samples on a custom geometric shapes dataset: [image]

Training

The full pipeline consists of three models: CLIP [Paper], the DALL-E 2 prior, and the DALL-E 2 decoder.

CLIP

CLIP is a zero-shot model that learns a shared, multimodal latent representation of text captions and images. Unlike standard image classification models, which use a feature extraction network followed by a final linear classifier, CLIP uses an image encoder and a text encoder to produce paired image and text embeddings in a shared latent space.
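The paired embeddings are trained with a symmetric contrastive objective: matching image/text pairs are pulled together, mismatched pairs pushed apart. A minimal sketch of that loss, assuming the two encoders emit fixed-size embedding batches (the function name and temperature value here are illustrative, not this repo's API):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of paired image/text embeddings."""
    # Normalise so the dot product is cosine similarity
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    # Pairwise similarity matrix: logits[i, j] = sim(image_i, text_j)
    logits = image_emb @ text_emb.t() / temperature
    # The matching pair for row i is column i
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_img = F.cross_entropy(logits, targets)      # image -> text direction
    loss_txt = F.cross_entropy(logits.t(), targets)  # text -> image direction
    return (loss_img + loss_txt) / 2
```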

To train DALL-E 2, you need to train CLIP first. To train CLIP, run

python clip/train.py

You have to specify the dataset path and the path where the final model is saved in model_config.yml.
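The README does not show the config file itself; the fragment below only illustrates the kind of entries model_config.yml might contain. All key names here are hypothetical — check the actual file in the repo for the real schema:

```yaml
# Hypothetical example only — see model_config.yml in the repo for the real keys
clip:
  dataset_path: /path/to/dataset
  model_save_path: /path/to/clip_checkpoint.pt
```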

Prior

The prior generates the CLIP image embedding based on the text caption.
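In the DALL-E 2 paper the prior is a diffusion (or autoregressive) model; the sketch below is only a toy MLP that illustrates the interface — text embedding in, predicted image embedding out. The class and dimensions are hypothetical, not this repo's implementation:

```python
import torch
import torch.nn as nn

class ToyPrior(nn.Module):
    """Toy stand-in for the prior: maps a CLIP text embedding to a
    predicted CLIP image embedding. The real DALL-E 2 prior is a
    diffusion model; this MLP only shows the input/output contract."""
    def __init__(self, embed_dim=512, hidden_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, text_emb):
        # text_emb: (batch, embed_dim) -> predicted image embedding, same shape
        return self.net(text_emb)
```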

To train the prior, run

python dalle2/train_prior.py

Similar to CLIP, you have to specify the dataset path and the model saving path in model_config.yml.

Decoder

The DALL-E 2 decoder is used to generate images conditioned on CLIP image embeddings and text captions.

To train the decoder, run

python dalle2/train_decoder.py

Do not forget to specify the paths in model_config.yml.

Sampling

The example below shows how to sample images from text captions:

import torch

# Initialise and load CLIP
# (load_state_dict expects a state dict, hence torch.load)
clip = CLIP(...)
clip_path = ...
clip.load_state_dict(torch.load(clip_path))

# Initialise and load prior
prior = Prior(...)
prior_path = ...
prior.load_state_dict(torch.load(prior_path))

# Initialise and load decoder
decoder = Decoder(...)
decoder_path = ...
decoder.load_state_dict(torch.load(decoder_path))

# Initialise DALL-E 2
dalle2 = DALLE2(clip, prior, decoder)

# Set DALL-E 2 to evaluation mode
dalle2.val_mode()

# Sample an image from a text caption; cf_guidance_scale is the classifier-free guidance scale
image_size = (3, 32, 32)
image = dalle2(image_size, text="a small black square and a large gold pentagon", cf_guidance_scale=2)
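cf_guidance_scale controls classifier-free guidance: at each sampling step the decoder's noise prediction is computed both with and without the conditioning, and the conditional prediction is extrapolated away from the unconditional one. A generic sketch of that combination step (the function and variable names are illustrative, not this repo's API):

```python
import torch

def cfg_combine(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance combination of two noise predictions.
    scale = 1 recovers the plain conditional prediction; larger scales
    trade sample diversity for stronger adherence to the text prompt."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```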
