-
Notifications
You must be signed in to change notification settings - Fork 423
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #140 from charlielito/feature/flax_tpu
Flax Implementation for TPU support
- Loading branch information
Showing
5 changed files
with
1,502 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,389 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/charlielito/stable-diffusion-videos/blob/feature%2Fflax_tpu/flax_stable_diffusion_videos.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "z4GhhH25OdYq" | ||
}, | ||
"source": [ | ||
"# Flax Stable Diffusion Videos\n", | ||
"\n", | ||
"This notebook allows you to generate videos by interpolating the latent space of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) using TPU for faster inference.\n", | ||
"\n", | ||
"In comparison with standard Colab GPU, this runs ~6x faster after the first run. The first run is comparable to the GPU version because it compiles the code.\n", | ||
"\n", | ||
"You can either dream up different versions of the same prompt, or morph between different text prompts (with seeds set for each for reproducibility).\n", | ||
"\n", | ||
"If you like this notebook:\n", | ||
"- consider giving the [repo a star](https://github.com/nateraw/stable-diffusion-videos) ⭐️\n", | ||
"- consider following us on Github [@nateraw](https://github.com/nateraw) [@charlielito](https://github.com/charlielito)\n", | ||
"\n", | ||
"You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues)\n", | ||
"\n", | ||
"Enjoy 🤗" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "dvdCBpWWOhW-" | ||
}, | ||
"source": [ | ||
"## Setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"#@title Set up JAX\n", | ||
"#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option \"Change runtime type\" and then select `TPU` under the `Hardware accelerator` setting.\n", | ||
"!pip install --upgrade jax jaxlib \n", | ||
"\n", | ||
"import jax.tools.colab_tpu\n", | ||
"jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')\n", | ||
"\n", | ||
"!pip install flax diffusers transformers ftfy\n", | ||
"jax.devices()" | ||
], | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "5EZdSq4HtmcE" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Xwfc0ej1L9A0" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"! pip install stable_diffusion_videos" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "H7UOKJhVOonb" | ||
}, | ||
"source": [ | ||
"## Run the App 🚀" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "g71hslP8OntM" | ||
}, | ||
"source": [ | ||
"### Load the Interface\n", | ||
"\n", | ||
"This step will take a couple minutes the first time you run it." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "bgSNS368L-DV" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"\n", | ||
"from jax import pmap\n", | ||
"from flax.jax_utils import replicate\n", | ||
"from flax.training.common_utils import shard\n", | ||
"from PIL import Image\n", | ||
"\n", | ||
"from stable_diffusion_videos import FlaxStableDiffusionWalkPipeline, Interface\n", | ||
"\n", | ||
"pipeline, params = FlaxStableDiffusionWalkPipeline.from_pretrained(\n", | ||
" \"CompVis/stable-diffusion-v1-4\", \n", | ||
" revision=\"bf16\", \n", | ||
" dtype=jnp.bfloat16\n", | ||
")\n", | ||
"p_params = replicate(params)\n", | ||
"\n", | ||
"interface = Interface(pipeline, params=p_params)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "kidtsR3c2P9Z" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Connect to Google Drive to Save Outputs\n", | ||
"\n", | ||
"#@markdown If you want to connect Google Drive, click the checkbox below and run this cell. You'll be prompted to authenticate.\n", | ||
"\n", | ||
"#@markdown If you just want to save your outputs in this Colab session, don't worry about this cell\n", | ||
"\n", | ||
"connect_google_drive = True #@param {type:\"boolean\"}\n", | ||
"\n", | ||
"#@markdown Then, in the interface, use this path as the `output` in the Video tab to save your videos to Google Drive:\n", | ||
"\n", | ||
"#@markdown > /content/gdrive/MyDrive/stable_diffusion_videos\n", | ||
"\n", | ||
"\n", | ||
"if connect_google_drive:\n", | ||
" from google.colab import drive\n", | ||
"\n", | ||
" drive.mount('/content/gdrive')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "VxjRVNnMOtgU" | ||
}, | ||
"source": [ | ||
"### Launch\n", | ||
"\n", | ||
"This cell launches a Gradio Interface. Here's how I suggest you use it:\n", | ||
"\n", | ||
"1. Use the \"Images\" tab to generate images you like.\n", | ||
" - Find two images you want to morph between\n", | ||
" - These images should use the same settings (guidance scale, height, width)\n", | ||
" - Keep track of the seeds/settings you used so you can reproduce them\n", | ||
"\n", | ||
"2. Generate videos using the \"Videos\" tab\n", | ||
" - Using the images you found from the step above, provide the prompts/seeds you recorded\n", | ||
" - Set the `num_interpolation_steps` - for testing you can use a small number like 3 or 5, but to get great results you'll want to use something larger (60-200 steps). \n", | ||
"\n", | ||
"💡 **Pro tip** - Click the link that looks like `https://<id-number>.gradio.app` below , and you'll be able to view it in full screen." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "8es3_onUOL3J" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"interface.launch(debug=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "mFCoTvlnPi4u" | ||
}, | ||
"source": [ | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "SjTQLCiLOWeo" | ||
}, | ||
"source": [ | ||
"## Use `walk` programmatically\n", | ||
"\n", | ||
"The other option is to not use the interface, and instead use `walk` programmatically. Here's how you would do that..." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "fGQPClGwOR9R" | ||
}, | ||
"source": [ | ||
"First we define a helper fn for visualizing videos in colab" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "GqTWc8ZhNeLU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import HTML\n", | ||
"from base64 import b64encode\n", | ||
"\n", | ||
"def visualize_video_colab(video_path):\n", | ||
" mp4 = open(video_path,'rb').read()\n", | ||
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", | ||
" return HTML(\"\"\"\n", | ||
" <video width=400 controls>\n", | ||
" <source src=\"%s\" type=\"video/mp4\">\n", | ||
" </video>\n", | ||
" \"\"\" % data_url)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "Vd_RzwkoPM7X" | ||
}, | ||
"source": [ | ||
"Walk! 🚶♀️" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Hv2wBZXXMQ-I" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"video_path = pipeline.walk(\n", | ||
" p_params,\n", | ||
" ['a cat', 'a dog'],\n", | ||
" [42, 1337],\n", | ||
" fps=5, # use 5 for testing, 25 or 30 for better quality\n", | ||
" num_interpolation_steps=30, # use 3-5 for testing, 30 or more for better results\n", | ||
" height=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n", | ||
" width=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n", | ||
" jit=True # To use all TPU cores\n", | ||
")\n", | ||
"visualize_video_colab(video_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "oLXULBMwSDnY" | ||
}, | ||
"source": [ | ||
"### Bonus! Music videos\n", | ||
"\n", | ||
"First, we'll need to install `youtube-dl`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"%%capture\n", | ||
"! pip install youtube-dl" | ||
], | ||
"metadata": { | ||
"id": "302zMC44aiC6" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Then, we can download an example music file. Here we download one from my soundcloud:" | ||
], | ||
"metadata": { | ||
"id": "Q3gCLCkLanzO" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"! youtube-dl -f bestaudio --extract-audio --audio-format mp3 --audio-quality 0 -o \"music/thoughts.%(ext)s\" https://soundcloud.com/nateraw/thoughts" | ||
], | ||
"metadata": { | ||
"id": "rEsTe_ujagE5" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from IPython.display import Audio\n", | ||
"\n", | ||
"Audio(filename='music/thoughts.mp3')" | ||
], | ||
"metadata": { | ||
"id": "RIKA-l5la28j" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "DsIxXFTKSG5j" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Seconds in the song\n", | ||
"audio_offsets = [7, 9]\n", | ||
"fps = 8\n", | ||
"\n", | ||
"# Convert seconds to frames\n", | ||
"num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]\n", | ||
"\n", | ||
"video_path = pipeline.walk(\n", | ||
" p_params,\n", | ||
" prompts=['blueberry spaghetti', 'strawberry spaghetti'],\n", | ||
" seeds=[42, 1337],\n", | ||
" num_interpolation_steps=num_interpolation_steps,\n", | ||
" height=512, # use multiples of 64\n", | ||
" width=512, # use multiples of 64\n", | ||
" audio_filepath='music/thoughts.mp3', # Use your own file\n", | ||
" audio_start_sec=audio_offsets[0], # Start second of the provided audio\n", | ||
" fps=fps, # important to set yourself based on the num_interpolation_steps you defined\n", | ||
" batch_size=2, # in TPU-v2 typically maximum of 3 for 512x512\n", | ||
" output_dir='./dreams', # Where images will be saved\n", | ||
" name=None, # Subdir of output dir. will be timestamp by default\n", | ||
" jit=True # To use all TPU cores\n", | ||
")\n", | ||
"visualize_video_colab(video_path)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "TPU", | ||
"colab": { | ||
"provenance": [], | ||
"include_colab_link": true | ||
}, | ||
"gpuClass": "standard", | ||
"kernelspec": { | ||
"display_name": "Python 3.9.12 ('base')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "7d7b96a25c39fa7937ff3ab94e1dd8c63b93cb924b8f0093093c6266e25a78bc" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.