Skip to content

Commit

Permalink
[WIP]Flax training script for controlnet (#2818)
Browse files Browse the repository at this point in the history
* add train_controlnet_flax

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
yiyixuxu and patrickvonplaten authored Mar 28, 2023
1 parent 58fc824 commit d4f846f
Show file tree
Hide file tree
Showing 2 changed files with 1,097 additions and 0 deletions.
96 changes: 96 additions & 0 deletions examples/controlnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,99 @@ image = pipe(
image.save("./output.png")
```

## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

### Running on Google Cloud TPU

See below for commands to set up a TPU VM(`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).

First create a single TPUv4-8 VM and connect to it:

```
ZONE=us-central2-b
TPU_TYPE=v4-8
VM_NAME=hg_flax
gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone $ZONE \
--accelerator-type $TPU_TYPE \
--version tpu-vm-v4-base
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
```

When connected install JAX `0.4.5`:

```
pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

To verify that JAX was correctly installed, you can run the following command:

```
import jax
jax.device_count()
```

This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.

Then install Diffusers and the library's training dependencies:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd in the example folder and run

```bash
pip install -U -r requirements_flax.txt
```

Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress

```
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
We encourage you to store or share your model with the community. To use huggingface hub, please login to your Hugging Face account, or ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):
```
huggingface-cli login
```
Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
```
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
```
And finally start the training
```
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--resolution=512 \
--learning_rate=1e-5 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=1000 \
--train_batch_size=2 \
--revision="non-ema" \
--from_pt \
--report_to="wandb" \
--max_train_steps=10000 \
--push_to_hub \
--hub_model_id=$HUB_MODEL_ID
```
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
Loading

0 comments on commit d4f846f

Please sign in to comment.