Text-to-image generative models are a powerful and innovative approach to creating visual artwork, and their rising popularity has given rise to the new field of prompt engineering. While there has been significant progress in prompt engineering for text generation, less work has rigorously examined how users can prompt generative models with natural language for visual generation.
In this project, we propose a novel Transformer-based ensemble model for the task of predicting the text prompt used to generate a given image. The predicted text prompt can then be edited and used to generate new images similar to the existing one. Our ensemble model uses embeddings derived from several models, such as ConvNext, CLIP, and BLIP, and leverages the attention mechanism of a transformer encoder to fuse these embeddings. We train and evaluate the model on a large dataset of (prompt, image) pairs from DiffusionDB and show that it generates text prompts similar to those originally used to generate the images.
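As a rough illustration of the embedding-fusion idea, the sketch below projects each backbone's image embedding into a common dimension and fuses the resulting tokens with a standard PyTorch transformer encoder. The dimensions, layer counts, pooling strategy, and class name are placeholders for illustration only, not the exact architecture implemented in this repository.

```python
import torch
import torch.nn as nn

class EmbeddingFusionSketch(nn.Module):
    """Illustrative fusion of per-backbone image embeddings via a transformer encoder.

    All dimensions below are placeholders, not the values used in train.py.
    """

    def __init__(self, input_dims=(1024, 768, 768), d_model=512, out_dim=384,
                 num_layers=2, num_heads=8):
        super().__init__()
        # One linear projection per backbone (e.g. ConvNext, CLIP, CLIP Interrogator/BLIP)
        # so every embedding lands in the same d_model-dimensional space.
        self.projections = nn.ModuleList([nn.Linear(d, d_model) for d in input_dims])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Regress the fused representation onto the prompt-embedding space.
        self.head = nn.Linear(d_model, out_dim)

    def forward(self, embeddings):
        # embeddings: list of tensors, one per backbone, each of shape (batch, input_dim_i)
        tokens = torch.stack(
            [proj(e) for proj, e in zip(self.projections, embeddings)], dim=1)
        fused = self.encoder(tokens)   # (batch, num_backbones, d_model)
        pooled = fused.mean(dim=1)     # attention-fused, then mean-pooled
        return self.head(pooled)       # predicted prompt embedding
```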
The Jupyter notebook demonstrating our proposed model pipeline, including embedding generation and inference with our Transformer ensemble model, can be found at demo_notebook.ipynb.
Please make sure that you have downloaded our Transformer ensemble model weights before attempting to run the demo. Links to the weights can be found in the Evaluation and Checkpoints section below.
We used Google Colaboratory for all tasks in this project, including data processing and training. Install the following additional requirements on top of the default environment provided by Google Colab.
!pip install sentence_transformers
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install pytorch_lightning torchmetrics
!pip install clip-interrogator==0.6.0
We use a set of 93,000 (image, prompt) pairs from the DiffusionDB dataset. You can download the first 100 parts and unzip the data as described here.
Then you can use the following commands to generate the embeddings required by our ensemble model: the ground-truth prompt embeddings, CLIP model embeddings, ConvNext model embeddings, and CLIP Interrogator embeddings.
python gt_prompt_embedding_creation.py
python clip_embedding_creation.py
python convnext_embedding_creation.py
python clip_intgtr_embedding_creation.py
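As a rough illustration of what these scripts compute, the sketch below encodes a prompt with a sentence-transformers model and the corresponding image with OpenAI's CLIP. The model variants, file names, and output handling here are assumptions for illustration; the actual scripts above may use different models and preprocessing.

```python
import clip
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Ground-truth prompt embedding (model name is illustrative).
text_encoder = SentenceTransformer("all-MiniLM-L6-v2", device=device)
prompt_embedding = text_encoder.encode("a castle floating in the clouds, digital art")

# CLIP image embedding for the corresponding generated image (path is illustrative).
clip_model, preprocess = clip.load("ViT-L/14", device=device)
image = preprocess(Image.open("example_image.png")).unsqueeze(0).to(device)
with torch.no_grad():
    image_embedding = clip_model.encode_image(image)

print(prompt_embedding.shape, image_embedding.shape)
```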
We have precomputed the embeddings required to train and evaluate our model. You can find all four embeddings (ground-truth prompt embeddings, CLIP model embeddings, ConvNext model embeddings, and CLIP Interrogator embeddings) for the three datasets through the Google Drive links in the table below.
| 10k dataset | 43k dataset | 93k dataset |
| --- | --- | --- |
| 10K_embeddings | 43K_embeddings | 93K_embeddings |
Please see the config files config-10k.yml, config-43k.yml, and config-93k.yml for setting hyperparameters and paths to data files.
The command for training is:
python train.py config-10k.yml
The average cosine similarity of our ensemble model on the validation sets of the 10k, 43k, and 93k datasets is shown below, along with links to the corresponding Transformer Ensemble Model checkpoints.
| | 10k dataset | 43k dataset | 93k dataset |
| --- | --- | --- | --- |
| Average cosine similarity | 0.701 | 0.687 | 0.690 |
| Model checkpoint | 10K_ensemble_model.ckpt | 43K_ensemble_model.ckpt | 93K_ensemble_model.ckpt |
The command to evaluate our Transformer ensemble model using the above model weights is:
python evaluate.py config-10k.yml "./10k_ensemble_model.ckpt"
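For reference, the average cosine similarity reported above is computed between the predicted and ground-truth prompt embeddings. The snippet below is a minimal sketch of that metric; the exact evaluation logic lives in evaluate.py, and the tensor shapes here are placeholders.

```python
import torch
import torch.nn.functional as F

def mean_cosine_similarity(predicted, target):
    """Average cosine similarity between predicted and ground-truth prompt embeddings.

    Both tensors have shape (num_examples, embedding_dim).
    """
    return F.cosine_similarity(predicted, target, dim=1).mean().item()

# Example with random placeholder embeddings.
pred = torch.randn(8, 384)
gt = torch.randn(8, 384)
print(mean_cosine_similarity(pred, gt))
```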