
Reinforcement Learning for Consistency Models (RLCM)

This is the official implementation of the paper "RL for Consistency Models: Faster Reward Guided Text-to-Image Generation". We support low-rank adaptation (LoRA) for fine-tuning a latent consistency model (LCM). Much of this code was inspired by the repository for Denoising Diffusion Policy Optimization (DDPO).

UPDATE: Now with REBEL support! Check out the REBEL paper for more information, or simply enable it by setting training.algorithm in config.yaml. You may need to increase the learning rate (and, depending on your GPU, decrease the train batch size), but otherwise all hyperparameters should stay the same.
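For intuition, REBEL (as we read the REBEL paper) replaces the clipped policy-gradient objective with a least-squares regression. For two samples y and y' generated from the same prompt x, the difference of their log-probability ratios (current policy vs. the policy that generated them), scaled by 1/eta, is regressed onto their reward difference. The function below is a minimal, illustrative sketch of that per-pair loss in plain Python. It is not the code this repository runs, and the function and argument names (including eta) are hypothetical. We assume each log-probability is the sum of per-step log-probabilities over one consistency-model sampling trajectory.

```python
def rebel_pair_loss(
    logp_new_a: float, logp_old_a: float,
    logp_new_b: float, logp_old_b: float,
    reward_a: float, reward_b: float,
    eta: float,
) -> float:
    """Squared-error REBEL-style loss for one pair of samples from the same prompt.

    logp_new_* are trajectory log-probabilities under the policy being trained,
    logp_old_* under the policy that generated the samples. eta is a
    hypothetical name for the REBEL scale parameter.
    """
    # Relative reward implied by the change in policy (difference of log-ratios)
    predicted = ((logp_new_a - logp_old_a) - (logp_new_b - logp_old_b)) / eta
    # Observed relative reward between the two samples
    target = reward_a - reward_b
    return (predicted - target) ** 2


# Made-up numbers: sample a became more likely, sample b is unchanged
print(rebel_pair_loss(-1.0, -1.5, -2.0, -2.0, 1.0, 0.5, eta=0.5))  # prints 0.25
```

When the implied relative reward exactly matches the observed one (for example eta=1.0 with the same inputs), the loss is zero.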

Installation

To install, clone this repository and install the package with pip (which runs setup.py):

git clone https://github.com/Owen-Oertell/rlcm.git
cd rlcm
pip install -e . 

You must have Python >= 3.10 installed.
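If you are unsure which interpreter pip will use, a quick standard-library check like the one below confirms the version requirement before installing. The helper name is our own; only the 3.10 minimum comes from this README.

```python
import sys

# Minimum interpreter version stated in this README
MIN_PYTHON = (3, 10)


def meets_min_python(version_info=sys.version_info) -> bool:
    """Return True if the given (major, minor, ...) version is at least MIN_PYTHON."""
    return tuple(version_info[:2]) >= MIN_PYTHON


current = sys.version.split()[0]
print("Python", current, "OK" if meets_min_python() else "is too old; 3.10+ is required")
```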

RLCM Training

To run RLCM, navigate into the scripts folder and launch main.py with accelerate. By default, training uses the compression task, but any of the other tasks can be selected.

For example, to run the aesthetic task:

accelerate launch main.py task=aesthetic

There are four tasks available: prompt_image_alignment, aesthetic, compression, and incompression. For more discussion about the tasks, please see the paper.

RLCM Inference

We also provide a sample inference script, inference.py in the scripts folder, for use once you have saved your models to disk. Edit it to point to your saved model (and update it if you changed the config), then run:

python inference.py

Summary of Hyperparameters

Below is a summary of the hyperparameters available for training. Each task has its own set of hyperparameters; the defaults are given in lcm_rl_pytorch/configs and can be overridden by passing them as arguments to main.py.

  • model. The model to fine-tune. To use it without modifying the code, it must be compatible with the LCM pipeline. This code uses SimianLuo/LCM_Dreamshaper_v7.
  • lr. The learning rate for the policy gradient updates.
  • sample_batch_size_per_gpu. The number of samples collected per forward pass of the model.
  • train_batch_size_per_gpu. The number of samples per training batch after rebatching. For memory reasons, this is (much) smaller than sample_batch_size_per_gpu.
  • num_sample_iters. The number of sampling forward passes performed before each round of updates.
  • clip_range. Clip range for the policy gradient (PG) update.
  • max_grad_norm. Maximum gradient norm used for gradient clipping.
  • adv_clip_max. Maximum advantage value used when clipping advantages.
  • gradient_accumulation_steps. Number of gradient accumulation steps.
  • num_epochs. Number of epochs to train for.
  • num_inference_steps. Number of inference steps for the LCM. More steps usually give higher image quality at the cost of slower training and inference.
  • batches_per_epoch. Number of batches per epoch.
  • stat_buffer_size. Size of the rolling buffer used for computing advantages.
  • stat_min_count. Minimum number of samples in a prompt's buffer before advantages are computed per prompt instead of from a global mean.
  • num_inner_epochs. Number of updates on a given batch. From our testing, 1 is usually sufficient.
  • lora_alpha. Low-rank adaptation alpha parameter.
  • lora_r. Low-rank adaptation rank parameter.
  • dataset. Name of the prompt dataset, as defined in lcm_rl_pytorch/core/dataloader.py.
  • reward_fn. Name of the reward function, as defined in lcm_rl_pytorch/rewards/rewards.py.
  • port. For the prompt_image_alignment task only: the port of the reward server, which must match the port the server is running on.
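To see how the sampling and training batch sizes interact, the sketch below works through the arithmetic for one epoch. It assumes DDPO-style semantics: each GPU runs num_sample_iters sampling passes of sample_batch_size_per_gpu samples, the pooled samples are rebatched into chunks of train_batch_size_per_gpu, and one optimizer step is taken every gradient_accumulation_steps train batches. Check the training loop in main.py before relying on these numbers; the helper name and example values are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class EpochPlan:
    samples_per_gpu: int        # samples collected on each GPU per epoch
    total_samples: int          # samples collected across all GPUs per epoch
    train_batches_per_gpu: int  # train batches per GPU per inner epoch
    optimizer_steps: int        # optimizer steps per epoch (all inner epochs)


def plan_epoch(
    sample_batch_size_per_gpu: int,
    num_sample_iters: int,
    train_batch_size_per_gpu: int,
    gradient_accumulation_steps: int,
    num_inner_epochs: int,
    num_gpus: int,
) -> EpochPlan:
    samples_per_gpu = sample_batch_size_per_gpu * num_sample_iters
    # Rebatching must split the collected samples evenly
    if samples_per_gpu % train_batch_size_per_gpu != 0:
        raise ValueError("samples per GPU must be divisible by train_batch_size_per_gpu")
    train_batches = samples_per_gpu // train_batch_size_per_gpu
    # Each optimizer step consumes gradient_accumulation_steps train batches
    if train_batches % gradient_accumulation_steps != 0:
        raise ValueError("train batches must be divisible by gradient_accumulation_steps")
    steps = (train_batches // gradient_accumulation_steps) * num_inner_epochs
    return EpochPlan(samples_per_gpu, samples_per_gpu * num_gpus, train_batches, steps)


# Hypothetical settings: 4 GPUs, 16 samples per pass, 4 passes, train batches of 2
print(plan_epoch(16, 4, 2, 8, 1, 4))
# EpochPlan(samples_per_gpu=64, total_samples=256, train_batches_per_gpu=32, optimizer_steps=4)
```

Raising sample_batch_size_per_gpu or num_sample_iters increases the number of samples per update, while train_batch_size_per_gpu mainly controls peak memory; gradient_accumulation_steps then restores the effective batch size per optimizer step.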
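The advantage settings (stat_buffer_size, stat_min_count, adv_clip_max) and clip_range follow the DDPO recipe this code builds on. The plain-Python sketch below shows how such settings are typically used: rewards are normalized per prompt using a rolling buffer once that buffer holds at least stat_min_count entries (falling back to the current batch's mean and standard deviation otherwise), advantages are clipped, and a PPO-style clipped surrogate loss is computed from old and new log-probabilities. It is modeled on DDPO's per-prompt stat tracker as we understand it, not copied from this repository; the class and function names are our own.

```python
import math
from collections import defaultdict, deque
from statistics import mean, pstdev


class PerPromptStatTracker:
    """Rolling per-prompt reward statistics (illustrative, DDPO-style)."""

    def __init__(self, stat_buffer_size: int, stat_min_count: int):
        self.stat_min_count = stat_min_count
        # One bounded buffer of recent rewards per prompt
        self.buffers = defaultdict(lambda: deque(maxlen=stat_buffer_size))

    def advantages(self, prompts: list[str], rewards: list[float]) -> list[float]:
        for prompt, reward in zip(prompts, rewards):
            self.buffers[prompt].append(reward)
        # Fallback statistics over the whole current batch
        batch_mean, batch_std = mean(rewards), pstdev(rewards)
        out = []
        for prompt, reward in zip(prompts, rewards):
            buf = self.buffers[prompt]
            if len(buf) < self.stat_min_count:
                mu, sigma = batch_mean, batch_std
            else:
                mu, sigma = mean(buf), pstdev(buf)
            out.append((reward - mu) / (sigma + 1e-6))
        return out


def clipped_pg_loss(advantages, logp_new, logp_old, clip_range, adv_clip_max):
    """Mean PPO-style clipped surrogate loss over a batch (to be minimized)."""
    losses = []
    for adv, new, old in zip(advantages, logp_new, logp_old):
        adv = max(-adv_clip_max, min(adv, adv_clip_max))  # clip the advantage
        ratio = math.exp(new - old)                      # importance ratio
        clipped_ratio = max(1 - clip_range, min(ratio, 1 + clip_range))
        # Pessimistic (larger) of the unclipped and clipped losses
        losses.append(max(-adv * ratio, -adv * clipped_ratio))
    return mean(losses)


tracker = PerPromptStatTracker(stat_buffer_size=16, stat_min_count=2)
adv = tracker.advantages(["a cat", "a cat", "a dog", "a dog"], [1.0, 3.0, 2.0, 2.0])
print([round(a, 3) for a in adv])
print(clipped_pg_loss([1.0], [0.5], [0.0], clip_range=0.2, adv_clip_max=5.0))
```

Per-prompt normalization keeps prompts that are intrinsically easy or hard from dominating the update, and the ratio clip bounds how far one batch can move the policy when num_inner_epochs > 1.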
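lora_r and lora_alpha are the standard LoRA hyperparameters: a frozen weight matrix W is augmented with a trainable low-rank product B A of rank r, scaled by alpha / r (the convention used by the PEFT library, for example). The toy pure-Python sketch below computes the adapted weight for tiny matrices to make the scaling concrete; it does not show how this repository attaches LoRA to the model, and the helper names are our own.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def lora_merged_weight(W, B, A, lora_alpha, lora_r):
    """Return W + (lora_alpha / lora_r) * B @ A.

    W is d x k (frozen); B is d x r and A is r x k (trainable), with r == lora_r.
    """
    if len(A) != lora_r or any(len(row) != lora_r for row in B):
        raise ValueError("B must be d x lora_r and A must be lora_r x k")
    scale = lora_alpha / lora_r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [0.0]]  # d x r with r = 1
A = [[0.5, 2.0]]    # r x k
print(lora_merged_weight(W, B, A, lora_alpha=2, lora_r=1))  # [[2.0, 4.0], [0.0, 1.0]]
```

Because B is initialized to zero in standard LoRA, the adapted model starts out identical to the base model; increasing lora_alpha relative to lora_r scales up the learned update.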

Reproducing Results

We reproduce our results from the paper below. Please see the paper's appendix for the full hyperparameters and number of GPUs used. At a high level, we used four RTX A6000 GPUs for each task; for the prompt image alignment task, three GPUs were used for training and one for the LLaVA server (from Kevin's repo). Make sure to use the 13B-parameter version of LLaVA; otherwise, we have experienced empty outputs.

Plots of performance versus runtime, measured in GPU hours. We report the runtime on four NVIDIA RTX A6000 GPUs across three random seeds and plot the mean and standard deviation. We observe that in all tasks RLCM noticeably reduces training time while achieving comparable or better reward scores.

Training curves for RLCM and DDPO by number of reward queries on compressibility, incompressibility, aesthetic, and prompt image alignment. We run three random seeds for each algorithm and plot the mean and standard deviation across those seeds. RLCM seems to produce either comparable or better reward optimization performance across these tasks.
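Both figures summarize three seeds per algorithm by their mean and standard deviation at each point on the x-axis. A minimal sketch of that aggregation follows, assuming the per-seed curves are already aligned on a common grid (GPU hours or reward queries) and using the sample standard deviation; the captions do not say which estimator was used, and the function name and data are hypothetical.

```python
from statistics import mean, stdev


def aggregate_seeds(curves: list[list[float]]) -> tuple[list[float], list[float]]:
    """Pointwise mean and sample standard deviation across seeds.

    Each inner list is one seed's curve, evaluated at the same x positions.
    """
    if len(curves) < 2:
        raise ValueError("need at least two seeds for a standard deviation")
    if len({len(c) for c in curves}) != 1:
        raise ValueError("all curves must have the same length")
    means = [mean(values) for values in zip(*curves)]
    stds = [stdev(values) for values in zip(*curves)]
    return means, stds


# Three hypothetical seeds, reward at four evaluation points
seeds = [[0.1, 0.4, 0.6, 0.7], [0.2, 0.5, 0.6, 0.8], [0.0, 0.3, 0.6, 0.6]]
means, stds = aggregate_seeds(seeds)
print([round(m, 3) for m in means], [round(s, 3) for s in stds])
```

The shaded band in such plots is then drawn between mean - std and mean + std at each x position.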

Citation

@misc{oertell2024rl,
      title={RL for Consistency Models: Faster Reward Guided Text-to-Image Generation}, 
      author={Owen Oertell and Jonathan D. Chang and Yiyi Zhang and Kianté Brantley and Wen Sun},
      year={2024},
      eprint={2404.03673},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}