-
Notifications
You must be signed in to change notification settings - Fork 423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flax Implementation for TPU support #140
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
c89e061
MVP without testing :p
charlielito 3ad0477
Fix imports
charlielito 43c4cb2
Add test script
charlielito c0516e1
Fix some issues
charlielito 3815019
Full support jit tested
charlielito e724b58
Delete test file
charlielito 9815b58
Fix batch size > 1
charlielito 10232e8
Add TODOs and handle num_neg with an NotImplented Error
charlielito dfb1967
Adapt app to run video generation with Flas pipeline
charlielito 481dc66
Add colab in readme
charlielito f95cacc
Created using Colaboratory
charlielito f20a0b3
Created using Colaboratory
charlielito 6be236a
Add readme url with future notebook
charlielito f824229
Merge branch 'feature/flax_tpu' of github.com:charlielito/stable-diff…
charlielito File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,389 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/charlielito/stable-diffusion-videos/blob/feature%2Fflax_tpu/flax_stable_diffusion_videos.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "z4GhhH25OdYq" | ||
}, | ||
"source": [ | ||
"# Flax Stable Diffusion Videos\n", | ||
"\n", | ||
"This notebook allows you to generate videos by interpolating the latent space of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) using TPU for faster inference.\n", | ||
"\n", | ||
"In comparison with standard Colab GPU, this runs ~6x faster after the first run. The first run is comparable to the GPU version because it compiles the code.\n", | ||
"\n", | ||
"You can either dream up different versions of the same prompt, or morph between different text prompts (with seeds set for each for reproducibility).\n", | ||
"\n", | ||
"If you like this notebook:\n", | ||
"- consider giving the [repo a star](https://github.com/nateraw/stable-diffusion-videos) ⭐️\n", | ||
"- consider following us on Github [@nateraw](https://github.com/nateraw) [@charlielito](https://github.com/charlielito)\n", | ||
"\n", | ||
"You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues)\n", | ||
"\n", | ||
"Enjoy 🤗" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "dvdCBpWWOhW-" | ||
}, | ||
"source": [ | ||
"## Setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"#@title Set up JAX\n", | ||
"#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option \"Change runtime type\" and then select `TPU` under the `Hardware accelerator` setting.\n", | ||
"!pip install --upgrade jax jaxlib \n", | ||
"\n", | ||
"import jax.tools.colab_tpu\n", | ||
"jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')\n", | ||
"\n", | ||
"!pip install flax diffusers transformers ftfy\n", | ||
"jax.devices()" | ||
], | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "5EZdSq4HtmcE" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Xwfc0ej1L9A0" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"! pip install stable_diffusion_videos" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "H7UOKJhVOonb" | ||
}, | ||
"source": [ | ||
"## Run the App 🚀" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "g71hslP8OntM" | ||
}, | ||
"source": [ | ||
"### Load the Interface\n", | ||
"\n", | ||
"This step will take a couple minutes the first time you run it." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "bgSNS368L-DV" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"\n", | ||
"from jax import pmap\n", | ||
"from flax.jax_utils import replicate\n", | ||
"from flax.training.common_utils import shard\n", | ||
"from PIL import Image\n", | ||
"\n", | ||
"from stable_diffusion_videos import FlaxStableDiffusionWalkPipeline, Interface\n", | ||
"\n", | ||
"pipeline, params = FlaxStableDiffusionWalkPipeline.from_pretrained(\n", | ||
" \"CompVis/stable-diffusion-v1-4\", \n", | ||
" revision=\"bf16\", \n", | ||
" dtype=jnp.bfloat16\n", | ||
")\n", | ||
"p_params = replicate(params)\n", | ||
"\n", | ||
"interface = Interface(pipeline, params=p_params)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "kidtsR3c2P9Z" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Connect to Google Drive to Save Outputs\n", | ||
"\n", | ||
"#@markdown If you want to connect Google Drive, click the checkbox below and run this cell. You'll be prompted to authenticate.\n", | ||
"\n", | ||
"#@markdown If you just want to save your outputs in this Colab session, don't worry about this cell\n", | ||
"\n", | ||
"connect_google_drive = True #@param {type:\"boolean\"}\n", | ||
"\n", | ||
"#@markdown Then, in the interface, use this path as the `output` in the Video tab to save your videos to Google Drive:\n", | ||
"\n", | ||
"#@markdown > /content/gdrive/MyDrive/stable_diffusion_videos\n", | ||
"\n", | ||
"\n", | ||
"if connect_google_drive:\n", | ||
" from google.colab import drive\n", | ||
"\n", | ||
" drive.mount('/content/gdrive')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "VxjRVNnMOtgU" | ||
}, | ||
"source": [ | ||
"### Launch\n", | ||
"\n", | ||
"This cell launches a Gradio Interface. Here's how I suggest you use it:\n", | ||
"\n", | ||
"1. Use the \"Images\" tab to generate images you like.\n", | ||
" - Find two images you want to morph between\n", | ||
" - These images should use the same settings (guidance scale, height, width)\n", | ||
" - Keep track of the seeds/settings you used so you can reproduce them\n", | ||
"\n", | ||
"2. Generate videos using the \"Videos\" tab\n", | ||
" - Using the images you found from the step above, provide the prompts/seeds you recorded\n", | ||
" - Set the `num_interpolation_steps` - for testing you can use a small number like 3 or 5, but to get great results you'll want to use something larger (60-200 steps). \n", | ||
"\n", | ||
"💡 **Pro tip** - Click the link that looks like `https://<id-number>.gradio.app` below , and you'll be able to view it in full screen." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "8es3_onUOL3J" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"interface.launch(debug=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "mFCoTvlnPi4u" | ||
}, | ||
"source": [ | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "SjTQLCiLOWeo" | ||
}, | ||
"source": [ | ||
"## Use `walk` programmatically\n", | ||
"\n", | ||
"The other option is to not use the interface, and instead use `walk` programmatically. Here's how you would do that..." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "fGQPClGwOR9R" | ||
}, | ||
"source": [ | ||
"First we define a helper fn for visualizing videos in colab" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "GqTWc8ZhNeLU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import HTML\n", | ||
"from base64 import b64encode\n", | ||
"\n", | ||
"def visualize_video_colab(video_path):\n", | ||
" mp4 = open(video_path,'rb').read()\n", | ||
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", | ||
" return HTML(\"\"\"\n", | ||
" <video width=400 controls>\n", | ||
" <source src=\"%s\" type=\"video/mp4\">\n", | ||
" </video>\n", | ||
" \"\"\" % data_url)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "Vd_RzwkoPM7X" | ||
}, | ||
"source": [ | ||
"Walk! 🚶♀️" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Hv2wBZXXMQ-I" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"video_path = pipeline.walk(\n", | ||
" p_params,\n", | ||
" ['a cat', 'a dog'],\n", | ||
" [42, 1337],\n", | ||
" fps=5, # use 5 for testing, 25 or 30 for better quality\n", | ||
" num_interpolation_steps=30, # use 3-5 for testing, 30 or more for better results\n", | ||
" height=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n", | ||
" width=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n", | ||
" jit=True # To use all TPU cores\n", | ||
")\n", | ||
"visualize_video_colab(video_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "oLXULBMwSDnY" | ||
}, | ||
"source": [ | ||
"### Bonus! Music videos\n", | ||
"\n", | ||
"First, we'll need to install `youtube-dl`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"%%capture\n", | ||
"! pip install youtube-dl" | ||
], | ||
"metadata": { | ||
"id": "302zMC44aiC6" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Then, we can download an example music file. Here we download one from my soundcloud:" | ||
], | ||
"metadata": { | ||
"id": "Q3gCLCkLanzO" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"! youtube-dl -f bestaudio --extract-audio --audio-format mp3 --audio-quality 0 -o \"music/thoughts.%(ext)s\" https://soundcloud.com/nateraw/thoughts" | ||
], | ||
"metadata": { | ||
"id": "rEsTe_ujagE5" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from IPython.display import Audio\n", | ||
"\n", | ||
"Audio(filename='music/thoughts.mp3')" | ||
], | ||
"metadata": { | ||
"id": "RIKA-l5la28j" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "DsIxXFTKSG5j" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Seconds in the song\n", | ||
"audio_offsets = [7, 9]\n", | ||
"fps = 8\n", | ||
"\n", | ||
"# Convert seconds to frames\n", | ||
"num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]\n", | ||
"\n", | ||
"video_path = pipeline.walk(\n", | ||
" p_params,\n", | ||
" prompts=['blueberry spaghetti', 'strawberry spaghetti'],\n", | ||
" seeds=[42, 1337],\n", | ||
" num_interpolation_steps=num_interpolation_steps,\n", | ||
" height=512, # use multiples of 64\n", | ||
" width=512, # use multiples of 64\n", | ||
" audio_filepath='music/thoughts.mp3', # Use your own file\n", | ||
" audio_start_sec=audio_offsets[0], # Start second of the provided audio\n", | ||
" fps=fps, # important to set yourself based on the num_interpolation_steps you defined\n", | ||
" batch_size=2, # in TPU-v2 typically maximum of 3 for 512x512\n", | ||
" output_dir='./dreams', # Where images will be saved\n", | ||
" name=None, # Subdir of output dir. will be timestamp by default\n", | ||
" jit=True # To use all TPU cores\n", | ||
")\n", | ||
"visualize_video_colab(video_path)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "TPU", | ||
"colab": { | ||
"provenance": [], | ||
"include_colab_link": true | ||
}, | ||
"gpuClass": "standard", | ||
"kernelspec": { | ||
"display_name": "Python 3.9.12 ('base')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "7d7b96a25c39fa7937ff3ab94e1dd8c63b93cb924b8f0093093c6266e25a78bc" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔥