diff --git a/README.md b/README.md
index 526ce88..2415bc6 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
# stable-diffusion-videos
Try it yourself in Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nateraw/stable-diffusion-videos/blob/main/stable_diffusion_videos.ipynb)
+TPU version (~6x faster than standard Colab GPUs): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nateraw/stable-diffusion-videos/blob/main/flax_stable_diffusion_videos.ipynb)
**Example** - morphing between "blueberry spaghetti" and "strawberry spaghetti"
diff --git a/flax_stable_diffusion_videos.ipynb b/flax_stable_diffusion_videos.ipynb
new file mode 100644
index 0000000..95ffcca
--- /dev/null
+++ b/flax_stable_diffusion_videos.ipynb
@@ -0,0 +1,389 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z4GhhH25OdYq"
+ },
+ "source": [
+ "# Flax Stable Diffusion Videos\n",
+ "\n",
+ "This notebook allows you to generate videos by interpolating the latent space of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) using TPU for faster inference.\n",
+ "\n",
+ "In comparison with standard Colab GPU, this runs ~6x faster after the first run. The first run is comparable to the GPU version because it compiles the code.\n",
+ "\n",
+ "You can either dream up different versions of the same prompt, or morph between different text prompts (with seeds set for each for reproducibility).\n",
+ "\n",
+ "If you like this notebook:\n",
+ "- consider giving the [repo a star](https://github.com/nateraw/stable-diffusion-videos) βοΈ\n",
+ "- consider following us on Github [@nateraw](https://github.com/nateraw) [@charlielito](https://github.com/charlielito)\n",
+ "\n",
+ "You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues)\n",
+ "\n",
+ "Enjoy π€"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dvdCBpWWOhW-"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Set up JAX\n",
+ "#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option \"Change runtime type\" and then select `TPU` under the `Hardware accelerator` setting.\n",
+ "!pip install --upgrade jax jaxlib \n",
+ "\n",
+ "import jax.tools.colab_tpu\n",
+ "jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')\n",
+ "\n",
+ "!pip install flax diffusers transformers ftfy\n",
+ "jax.devices()"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "5EZdSq4HtmcE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Xwfc0ej1L9A0"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "! pip install stable_diffusion_videos"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "H7UOKJhVOonb"
+ },
+ "source": [
+ "## Run the App π"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g71hslP8OntM"
+ },
+ "source": [
+ "### Load the Interface\n",
+ "\n",
+ "This step will take a couple minutes the first time you run it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bgSNS368L-DV"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "from jax import pmap\n",
+ "from flax.jax_utils import replicate\n",
+ "from flax.training.common_utils import shard\n",
+ "from PIL import Image\n",
+ "\n",
+ "from stable_diffusion_videos import FlaxStableDiffusionWalkPipeline, Interface\n",
+ "\n",
+ "pipeline, params = FlaxStableDiffusionWalkPipeline.from_pretrained(\n",
+ " \"CompVis/stable-diffusion-v1-4\", \n",
+ " revision=\"bf16\", \n",
+ " dtype=jnp.bfloat16\n",
+ ")\n",
+ "p_params = replicate(params)\n",
+ "\n",
+ "interface = Interface(pipeline, params=p_params)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "kidtsR3c2P9Z"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Connect to Google Drive to Save Outputs\n",
+ "\n",
+ "#@markdown If you want to connect Google Drive, click the checkbox below and run this cell. You'll be prompted to authenticate.\n",
+ "\n",
+ "#@markdown If you just want to save your outputs in this Colab session, don't worry about this cell\n",
+ "\n",
+ "connect_google_drive = True #@param {type:\"boolean\"}\n",
+ "\n",
+ "#@markdown Then, in the interface, use this path as the `output` in the Video tab to save your videos to Google Drive:\n",
+ "\n",
+ "#@markdown > /content/gdrive/MyDrive/stable_diffusion_videos\n",
+ "\n",
+ "\n",
+ "if connect_google_drive:\n",
+ " from google.colab import drive\n",
+ "\n",
+ " drive.mount('/content/gdrive')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VxjRVNnMOtgU"
+ },
+ "source": [
+ "### Launch\n",
+ "\n",
+ "This cell launches a Gradio Interface. Here's how I suggest you use it:\n",
+ "\n",
+ "1. Use the \"Images\" tab to generate images you like.\n",
+ " - Find two images you want to morph between\n",
+ " - These images should use the same settings (guidance scale, height, width)\n",
+ " - Keep track of the seeds/settings you used so you can reproduce them\n",
+ "\n",
+ "2. Generate videos using the \"Videos\" tab\n",
+ " - Using the images you found from the step above, provide the prompts/seeds you recorded\n",
+ " - Set the `num_interpolation_steps` - for testing you can use a small number like 3 or 5, but to get great results you'll want to use something larger (60-200 steps). \n",
+ "\n",
+ "π‘ **Pro tip** - Click the link that looks like `https://.gradio.app` below , and you'll be able to view it in full screen."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8es3_onUOL3J"
+ },
+ "outputs": [],
+ "source": [
+ "interface.launch(debug=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mFCoTvlnPi4u"
+ },
+ "source": [
+ "---"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SjTQLCiLOWeo"
+ },
+ "source": [
+ "## Use `walk` programmatically\n",
+ "\n",
+ "The other option is to not use the interface, and instead use `walk` programmatically. Here's how you would do that..."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fGQPClGwOR9R"
+ },
+ "source": [
+ "First we define a helper fn for visualizing videos in colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GqTWc8ZhNeLU"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import HTML\n",
+ "from base64 import b64encode\n",
+ "\n",
+ "def visualize_video_colab(video_path):\n",
+ " mp4 = open(video_path,'rb').read()\n",
+ " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ " return HTML(\"\"\"\n",
+ " \n",
+ " \"\"\" % data_url)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Vd_RzwkoPM7X"
+ },
+ "source": [
+ "Walk! πΆββοΈ"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Hv2wBZXXMQ-I"
+ },
+ "outputs": [],
+ "source": [
+ "video_path = pipeline.walk(\n",
+ " p_params,\n",
+ " ['a cat', 'a dog'],\n",
+ " [42, 1337],\n",
+ " fps=5, # use 5 for testing, 25 or 30 for better quality\n",
+ " num_interpolation_steps=30, # use 3-5 for testing, 30 or more for better results\n",
+ " height=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n",
+ " width=512, # use multiples of 64 if > 512. Multiples of 8 if < 512.\n",
+ " jit=True # To use all TPU cores\n",
+ ")\n",
+ "visualize_video_colab(video_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oLXULBMwSDnY"
+ },
+ "source": [
+ "### Bonus! Music videos\n",
+ "\n",
+ "First, we'll need to install `youtube-dl`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "! pip install youtube-dl"
+ ],
+ "metadata": {
+ "id": "302zMC44aiC6"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Then, we can download an example music file. Here we download one from my soundcloud:"
+ ],
+ "metadata": {
+ "id": "Q3gCLCkLanzO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "! youtube-dl -f bestaudio --extract-audio --audio-format mp3 --audio-quality 0 -o \"music/thoughts.%(ext)s\" https://soundcloud.com/nateraw/thoughts"
+ ],
+ "metadata": {
+ "id": "rEsTe_ujagE5"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "Audio(filename='music/thoughts.mp3')"
+ ],
+ "metadata": {
+ "id": "RIKA-l5la28j"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DsIxXFTKSG5j"
+ },
+ "outputs": [],
+ "source": [
+ "# Seconds in the song\n",
+ "audio_offsets = [7, 9]\n",
+ "fps = 8\n",
+ "\n",
+ "# Convert seconds to frames\n",
+ "num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]\n",
+ "\n",
+ "video_path = pipeline.walk(\n",
+ " p_params,\n",
+ " prompts=['blueberry spaghetti', 'strawberry spaghetti'],\n",
+ " seeds=[42, 1337],\n",
+ " num_interpolation_steps=num_interpolation_steps,\n",
+ " height=512, # use multiples of 64\n",
+ " width=512, # use multiples of 64\n",
+ " audio_filepath='music/thoughts.mp3', # Use your own file\n",
+ " audio_start_sec=audio_offsets[0], # Start second of the provided audio\n",
+ " fps=fps, # important to set yourself based on the num_interpolation_steps you defined\n",
+ " batch_size=2, # in TPU-v2 typically maximum of 3 for 512x512\n",
+ " output_dir='./dreams', # Where images will be saved\n",
+ " name=None, # Subdir of output dir. will be timestamp by default\n",
+ " jit=True # To use all TPU cores\n",
+ ")\n",
+ "visualize_video_colab(video_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3.9.12 ('base')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "7d7b96a25c39fa7937ff3ab94e1dd8c63b93cb924b8f0093093c6266e25a78bc"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py
index 0ba78e4..36b4873 100644
--- a/stable_diffusion_videos/__init__.py
+++ b/stable_diffusion_videos/__init__.py
@@ -61,7 +61,9 @@ def _attach(package_name, submodules=None, submod_attrs=None):
else:
submodules = set(submodules)
- attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}
+ attr_to_modules = {
+ attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
+ }
__all__ = list(submodules | attr_to_modules.keys())
@@ -110,6 +112,7 @@ def __dir__():
"make_video_pyav",
"get_timesteps_arr",
],
+ "flax_stable_diffusion_pipeline": ["FlaxStableDiffusionWalkPipeline"],
"upsampling": ["RealESRGANModel"],
},
)
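The `__init__.py` hunk above registers `FlaxStableDiffusionWalkPipeline` in the package's lazy-import table, so `import stable_diffusion_videos` stays cheap and JAX/Flax is only pulled in when the attribute is first accessed. As a rough, hedged illustration of the PEP 562-style mechanism that `_attach` appears to implement (the helper's full body is not shown in this diff, so treat the names below as a sketch):

```python
import importlib

# Illustrative only: the real mapping is built inside _attach() from `submod_attrs`.
submod_attrs = {"flax_stable_diffusion_pipeline": ["FlaxStableDiffusionWalkPipeline"]}
attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}

def __getattr__(name):
    # Import the owning submodule lazily, on first attribute access.
    if name in attr_to_modules:
        submodule = importlib.import_module(f"{__name__}.{attr_to_modules[name]}")
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```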
diff --git a/stable_diffusion_videos/app.py b/stable_diffusion_videos/app.py
index 96aa9a6..d64e9c7 100644
--- a/stable_diffusion_videos/app.py
+++ b/stable_diffusion_videos/app.py
@@ -6,46 +6,53 @@
class Interface:
- def __init__(self, pipeline):
+ def __init__(self, pipeline, params=None):
self.pipeline = pipeline
+ self.params = params # params are provided when using the Flax pipeline
self.interface_images = gr.Interface(
self.fn_images,
inputs=[
- gr.Textbox("blueberry spaghetti", label='Prompt'),
- gr.Slider(1, 24, 1, step=1, label='Batch size'),
- gr.Slider(1, 16, 1, step=1, label='# Batches'),
- gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
- gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
- gr.Slider(512, 1024, 512, step=64, label='Height'),
- gr.Slider(512, 1024, 512, step=64, label='Width'),
- gr.Checkbox(False, label='Upsample'),
- gr.Textbox("./images", label='Output directory to save results to'),
+ gr.Textbox("blueberry spaghetti", label="Prompt"),
+ gr.Slider(1, 24, 1, step=1, label="Batch size"),
+ gr.Slider(1, 16, 1, step=1, label="# Batches"),
+ gr.Slider(10, 100, 50, step=1, label="# Inference Steps"),
+ gr.Slider(5.0, 15.0, 7.5, step=0.5, label="Guidance Scale"),
+ gr.Slider(512, 1024, 512, step=64, label="Height"),
+ gr.Slider(512, 1024, 512, step=64, label="Width"),
+ gr.Checkbox(False, label="Upsample"),
+ gr.Textbox("./images", label="Output directory to save results to"),
# gr.Checkbox(False, label='Push results to Hugging Face Hub'),
# gr.Textbox("", label='Hugging Face Repo ID to push images to'),
],
- outputs=gr.Gallery(),
+ outputs=gr.Gallery() if self.params is None else gr.Textbox(),
)
self.interface_videos = gr.Interface(
self.fn_videos,
inputs=[
- gr.Textbox("blueberry spaghetti\nstrawberry spaghetti", lines=2, label='Prompts, separated by new line'),
- gr.Textbox("42\n1337", lines=2, label='Seeds, separated by new line'),
- gr.Slider(3, 1000, 5, step=1, label='# Interpolation Steps between prompts'),
- gr.Slider(3, 60, 5, step=1, label='Output Video FPS'),
- gr.Slider(1, 24, 1, step=1, label='Batch size'),
- gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
- gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
- gr.Slider(512, 1024, 512, step=64, label='Height'),
- gr.Slider(512, 1024, 512, step=64, label='Width'),
- gr.Checkbox(False, label='Upsample'),
- gr.Textbox("./dreams", label='Output directory to save results to'),
+ gr.Textbox(
+ "blueberry spaghetti\nstrawberry spaghetti",
+ lines=2,
+ label="Prompts, separated by new line",
+ ),
+ gr.Textbox("42\n1337", lines=2, label="Seeds, separated by new line"),
+ gr.Slider(
+ 3, 1000, 5, step=1, label="# Interpolation Steps between prompts"
+ ),
+ gr.Slider(3, 60, 5, step=1, label="Output Video FPS"),
+ gr.Slider(1, 24, 1, step=1, label="Batch size"),
+ gr.Slider(10, 100, 50, step=1, label="# Inference Steps"),
+ gr.Slider(5.0, 15.0, 7.5, step=0.5, label="Guidance Scale"),
+ gr.Slider(512, 1024, 512, step=64, label="Height"),
+ gr.Slider(512, 1024, 512, step=64, label="Width"),
+ gr.Checkbox(False, label="Upsample"),
+ gr.Textbox("./dreams", label="Output directory to save results to"),
],
outputs=gr.Video(),
)
self.interface = gr.TabbedInterface(
[self.interface_images, self.interface_videos],
- ['Images!', 'Videos!'],
+ ["Images!", "Videos!"],
)
def fn_videos(
@@ -62,10 +69,10 @@ def fn_videos(
upsample,
output_dir,
):
- prompts = [x.strip() for x in prompts.split('\n') if x.strip()]
- seeds = [int(x.strip()) for x in seeds.split('\n') if x.strip()]
+ prompts = [x.strip() for x in prompts.split("\n") if x.strip()]
+ seeds = [int(x.strip()) for x in seeds.split("\n") if x.strip()]
- return self.pipeline.walk(
+ kwargs = dict(
prompts=prompts,
seeds=seeds,
num_interpolation_steps=num_interpolation_steps,
@@ -76,8 +83,12 @@ def fn_videos(
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
upsample=upsample,
- batch_size=batch_size
+ batch_size=batch_size,
)
+ if self.params is not None:
+ # Assume Flax pipeline, force jit, params should be already replicated
+ kwargs.update(dict(params=self.params, jit=True))
+ return self.pipeline.walk(**kwargs)
def fn_images(
self,
@@ -93,6 +104,8 @@ def fn_images(
repo_id=None,
push_to_hub=False,
):
+ if self.params is not None:
+ return "Single image generation is not supported for Flax yet. Go to the videos tab."
image_filepaths = generate_images(
self.pipeline,
prompt,
@@ -101,7 +114,7 @@ def fn_images(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
output_dir=output_dir,
- image_file_ext='.jpg',
+ image_file_ext=".jpg",
upsample=upsample,
height=height,
width=width,
@@ -112,4 +125,4 @@ def fn_images(
return [(x, Path(x).stem) for x in sorted(image_filepaths)]
def launch(self, *args, **kwargs):
- self.interface.launch(*args, **kwargs)
\ No newline at end of file
+ self.interface.launch(*args, **kwargs)
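With these changes, `Interface` accepts an optional `params` argument: when it is provided, the app assumes a Flax pipeline, forces `jit=True` in `walk`, and replaces the single-image tab's output with a message. A minimal sketch of both construction paths, reusing the model id and replicated params from the notebook above (the commented PyTorch line is illustrative):

```python
import jax.numpy as jnp
from flax.jax_utils import replicate

from stable_diffusion_videos import FlaxStableDiffusionWalkPipeline, Interface

# Flax/TPU path: pass replicated params so walk() runs with jit=True across devices.
pipeline, params = FlaxStableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)
interface = Interface(pipeline, params=replicate(params))

# PyTorch path (unchanged behavior): params stays None.
# interface = Interface(StableDiffusionWalkPipeline.from_pretrained("CompVis/stable-diffusion-v1-4"))

interface.launch()
```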
diff --git a/stable_diffusion_videos/flax_stable_diffusion_pipeline.py b/stable_diffusion_videos/flax_stable_diffusion_pipeline.py
new file mode 100644
index 0000000..57136dc
--- /dev/null
+++ b/stable_diffusion_videos/flax_stable_diffusion_pipeline.py
@@ -0,0 +1,1066 @@
+import json
+import time
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import librosa
+import numpy as np
+import torch
+from diffusers.models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from diffusers.pipeline_flax_utils import FlaxDiffusionPipeline
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker_flax import (
+ FlaxStableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from diffusers.utils import deprecate, logging
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from packaging import version
+from PIL import Image
+from torchvision.io import write_video
+from torchvision.transforms.functional import pil_to_tensor
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+NUM_TPU_CORES = jax.device_count()
+
+from .upsampling import RealESRGANModel
+
+
+def pad_along_axis(array: np.ndarray, pad_size: int, axis: int = 0) -> np.ndarray:
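+ """Zero-pad `array` by `pad_size` entries along `axis`; used below to make batches divisible by the device count."""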
+ if pad_size <= 0:
+ return array
+ npad = [(0, 0)] * array.ndim
+ npad[axis] = (0, pad_size)
+ return np.pad(array, pad_width=npad, mode="constant", constant_values=0)
+
+
+def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
+ y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
+
+ # librosa.stft hardcoded defaults...
+ # n_fft defaults to 2048
+ # hop length is win_length // 4
+ # win_length defaults to n_fft
+ D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
+
+ # Extract percussive elements
+ D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
+ y_percussive = librosa.istft(D_percussive, length=len(y))
+
+ # Get normalized melspectrogram
+ spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
+ spec_max = np.amax(spec_raw, axis=0)
+ spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
+
+ # Resize cumsum of spec norm to our desired number of interpolation frames
+ x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
+ y_norm = np.cumsum(spec_norm)
+ y_norm /= y_norm[-1]
+ x_resize = np.linspace(0, y_norm.shape[-1], int(duration * fps))
+
+ T = np.interp(x_resize, x_norm, y_norm)
+
+ # Apply smoothing
+ return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth
+
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+ """helper function to spherically interpolate two arrays v1 v2"""
+
+ # if not isinstance(v0, np.ndarray):
+ # inputs_are_torch = True
+ # input_device = v0.device
+ # v0 = v0.cpu().numpy()
+ # v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ # if inputs_are_torch:
+ # v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+def make_video_pyav(
+ frames_or_frame_dir: Union[str, Path, torch.Tensor],
+ audio_filepath: Union[str, Path] = None,
+ fps: int = 30,
+ audio_offset: int = 0,
+ audio_duration: int = 2,
+ sr: int = 22050,
+ output_filepath: Union[str, Path] = "output.mp4",
+ glob_pattern: str = "*.png",
+):
+ """
+ Write an mp4 video from a directory of frames (or a frame tensor), optionally muxing in audio.
+
+ frames_or_frame_dir: (Union[str, Path, torch.Tensor]):
+ Either a directory of images, or a tensor of shape (T, C, H, W) in range [0, 255].
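+ audio_filepath: (Union[str, Path], optional):
+ Path to an audio file to mux into the output video.
+ fps: (int):
+ Frames per second of the output video.
+ audio_offset: (int):
+ Start time, in seconds, within the audio file.
+ audio_duration: (int):
+ Duration, in seconds, of audio to include.
+ sr: (int):
+ Sample rate at which to load the audio.
+ output_filepath: (Union[str, Path]):
+ Path of the resulting mp4.
+ glob_pattern: (str):
+ Glob pattern used to collect frames when a directory is given.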
+ """
+
+ # Torchvision write_video doesn't support pathlib paths
+ output_filepath = str(output_filepath)
+
+ if isinstance(frames_or_frame_dir, (str, Path)):
+ frames = None
+ for img in sorted(Path(frames_or_frame_dir).glob(glob_pattern)):
+ frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
+ frames = frame if frames is None else torch.cat([frames, frame])
+ else:
+ frames = frames_or_frame_dir
+
+ # TCHW -> THWC
+ frames = frames.permute(0, 2, 3, 1)
+
+ if audio_filepath:
+ # Read audio, convert to tensor
+ audio, sr = librosa.load(
+ audio_filepath,
+ sr=sr,
+ mono=True,
+ offset=audio_offset,
+ duration=audio_duration,
+ )
+ audio_tensor = torch.tensor(audio).unsqueeze(0)
+
+ write_video(
+ output_filepath,
+ frames,
+ fps=fps,
+ audio_array=audio_tensor,
+ audio_fps=sr,
+ audio_codec="aac",
+ options={"crf": "10", "pix_fmt": "yuv420p"},
+ )
+ else:
+ write_video(
+ output_filepath,
+ frames,
+ fps=fps,
+ options={"crf": "10", "pix_fmt": "yuv420p"},
+ )
+
+ return output_filepath
+
+
+class FlaxStableDiffusionWalkPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for generating videos by interpolating Stable Diffusion's latent space.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ scheduler: Union[
+ FlaxDDIMScheduler,
+ FlaxPNDMScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(
+ unet.config, "_diffusers_version"
+ ) and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse(
+ "0.9.0.dev0"
+ )
+ is_unet_sample_size_less_64 = (
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate(
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_inputs(self, prompt: Union[str, List[str]]):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ return text_input.input_ids
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(
+ self, features, safety_model_params
+ )
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(
+ features, safety_model_params
+ )
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.PRNGKey,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: float = 7.5,
+ latents: Optional[jnp.array] = None,
+ neg_prompt_ids: jnp.array = None,
+ text_embeddings: Optional[jnp.array] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
+
+ if text_embeddings is None:
+ # get prompt text embeddings
+ text_embeddings = self.text_encoder(
+ prompt_ids, params=params["text_encoder"]
+ )[0]
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+ else:
+ batch_size = text_embeddings.shape[0]
+ # TODO: check if this is enough
+ max_length = self.tokenizer.model_max_length
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="np",
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ uncond_embeddings = self.text_encoder(
+ uncond_input, params=params["text_encoder"]
+ )[0]
+ context = jnp.concatenate([uncond_embeddings, text_embeddings])
+
+ latents_shape = (
+ batch_size,
+ self.unet.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = jax.random.normal(
+ prng_seed, shape=latents_shape, dtype=jnp.float32
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(
+ f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
+ )
+
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(
+ scheduler_state, latents_input, t
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_prediction_text - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(
+ scheduler_state, noise_pred, t, latents
+ ).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"],
+ num_inference_steps=num_inference_steps,
+ shape=latents.shape,
+ )
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ if DEBUG:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(
+ 0, num_inference_steps, loop_body, (latents, scheduler_state)
+ )
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.apply(
+ {"params": params["vae"]}, latents, method=self.vae.decode
+ ).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ def __call__(
+ self,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.PRNGKey,
+ prompt_ids: Optional[jnp.array] = None,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: Union[float, jnp.array] = 7.5,
+ latents: jnp.array = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ jit: bool = False,
+ neg_prompt_ids: jnp.array = None,
+ text_embeddings: Optional[jnp.array] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ params (`Union[Dict, FrozenDict]`): The model parameters.
+ prng_seed (`jax.random.PRNGKey`): The random seed used for sampling the noise.
+ prompt_ids (`jnp.array`, *optional*, defaults to `None`):
+ The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling with the supplied `prng_seed`.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ neg_prompt_ids (`jnp.array`, *optional*):
+ Token ids of the prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ text_embeddings (`jnp.array`, *optional*, defaults to `None`):
+ Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+ `prompt_ids` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+ the supplied `prompt_ids`.
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if prompt_ids is None and text_embeddings is None:
+ raise ValueError(
+ "Either `prompt_ids` or `text_embeddings` must be provided."
+ )
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ text_embeddings,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ text_embeddings,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(
+ num_devices * batch_size, height, width, 3
+ )
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(
+ images_uint8_casted, safety_params, jit
+ )
+ images = np.asarray(images).reshape(
+ num_devices * batch_size, height, width, 3
+ )
+
+ # block images
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if jit:
+ images = unshard(images)
+
+ # Convert to PIL
+ if output_type == "pil":
+ images = self.numpy_to_pil(images)
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(
+ images=images, nsfw_content_detected=has_nsfw_concept
+ )
+
+ def generate_inputs(
+ self, params, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size
+ ):
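+ """Yield (batch_idx, embeds_batch, noise_batch) tuples, slerping between the two prompts' text embeddings and the two seeds' initial noise at each t in T, grouped into batches of `batch_size`."""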
+ embeds_a = self.embed_text(params, prompt_a)
+ embeds_b = self.embed_text(params, prompt_b)
+ latents_dtype = embeds_a.dtype
+ latents_a = self.init_noise(seed_a, noise_shape, latents_dtype)
+ latents_b = self.init_noise(seed_b, noise_shape, latents_dtype)
+
+ batch_idx = 0
+ embeds_batch, noise_batch = None, None
+ for i, t in enumerate(T):
+ embeds = slerp(float(t), embeds_a, embeds_b)
+ noise = slerp(float(t), latents_a, latents_b)
+
+ embeds_batch = (
+ embeds
+ if embeds_batch is None
+ else np.concatenate([embeds_batch, embeds])
+ )
+ noise_batch = (
+ noise if noise_batch is None else np.concatenate([noise_batch, noise])
+ )
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
+ if not batch_is_ready:
+ continue
+ yield batch_idx, embeds_batch, noise_batch
+ batch_idx += 1
+ del embeds_batch, noise_batch
+ # torch.cuda.empty_cache()
+ embeds_batch, noise_batch = None, None
+
+ def make_clip_frames(
+ self,
+ params: Union[Dict, FrozenDict],
+ prompt_a: str,
+ prompt_b: str,
+ seed_a: int,
+ seed_b: int,
+ num_interpolation_steps: int = 5,
+ save_path: Union[str, Path] = "outputs/",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ eta: float = 0.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ upsample: bool = False,
+ batch_size: int = 1,
+ image_file_ext: str = ".png",
+ T: np.ndarray = None,
+ skip: int = 0,
+ negative_prompt: str = None,
+ jit: bool = False,
+ ):
+ if negative_prompt is not None:
+ raise NotImplementedError(
+ "Negative prompt is not supported for make_clip_frames yet."
+ )
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ save_path = Path(save_path)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ T = T if T is not None else np.linspace(0.0, 1.0, num_interpolation_steps)
+ if T.shape[0] != num_interpolation_steps:
+ raise ValueError(
+ f"Unexpected T shape, got {T.shape}, expected dim 0 to be {num_interpolation_steps}"
+ )
+
+ if upsample:
+ if getattr(self, "upsampler", None) is None:
+ # TODO: port to flax
+ self.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
+ if not torch.cuda.is_available():
+ logger.warning(
+ "Upsampling is recommended to be done on a GPU, as it is very slow on CPU"
+ )
+ else:
+ self.upsampler = self.upsampler.cuda()
+
+ seed_a = jax.random.PRNGKey(seed_a)
+ seed_b = jax.random.PRNGKey(seed_b)
+
+ text_encoder_params = params["text_encoder"]
+ if jit: # if jit, assume params are replicated
+ # for encoding the prompts we run it on a single device
+ text_encoder_params = unreplicate(text_encoder_params)
+
+ batch_generator = self.generate_inputs(
+ text_encoder_params,
+ prompt_a,
+ prompt_b,
+ seed_a,
+ seed_b,
+ (1, self.unet.in_channels, height // 8, width // 8),
+ T[skip:],
+ batch_size=NUM_TPU_CORES * batch_size if jit else batch_size,
+ )
+
+ # TODO: convert negative_prompt to neg_prompt_ids
+
+ frame_index = skip
+ for _, embeds_batch, noise_batch in batch_generator:
+ if jit:
+ padded = False
+ # Check if embeds_batch's leading dimension is a multiple of NUM_TPU_CORES; if not, pad
+ if embeds_batch.shape[0] % NUM_TPU_CORES != 0:
+ padded = True
+ pad_size = NUM_TPU_CORES - (embeds_batch.shape[0] % NUM_TPU_CORES)
+ # Pad embeds_batch and noise_batch with zeros in batch dimension
+ embeds_batch = pad_along_axis(embeds_batch, pad_size, axis=0)
+ noise_batch = pad_along_axis(noise_batch, pad_size, axis=0)
+ embeds_batch = shard(embeds_batch)
+ noise_batch = shard(noise_batch)
+
+ outputs = self(
+ params,
+ prng_seed=None,
+ latents=noise_batch,
+ text_embeddings=embeds_batch,
+ height=height,
+ width=width,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ output_type="pil" if not upsample else "numpy",
+ neg_prompt_ids=negative_prompt,
+ jit=jit,
+ )["images"]
+
+ if jit:
+ # check if we padded and remove that padding from outputs
+ if padded:
+ outputs = outputs[:-pad_size]
+
+ for image in outputs:
+ frame_filepath = save_path / (
+ f"frame%06d{image_file_ext}" % frame_index
+ )
+ # image = image if not upsample else self.upsampler(image)
+ image.save(frame_filepath)
+ frame_index += 1
+
+ def walk(
+ self,
+ params: Union[Dict, FrozenDict],
+ prompts: Optional[List[str]] = None,
+ seeds: Optional[List[int]] = None,
+ num_interpolation_steps: Optional[
+ Union[int, List[int]]
+ ] = 5, # int or list of int
+ output_dir: Optional[str] = "./dreams",
+ name: Optional[str] = None,
+ image_file_ext: Optional[str] = ".png",
+ fps: Optional[int] = 30,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ upsample: Optional[bool] = False,
+ batch_size: Optional[int] = 1,
+ resume: Optional[bool] = False,
+ audio_filepath: str = None,
+ audio_start_sec: Optional[Union[int, float]] = None,
+ margin: Optional[float] = 1.0,
+ smooth: Optional[float] = 0.0,
+ negative_prompt: Optional[str] = None,
+ jit: bool = False,
+ ):
+ """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
+ video to interpolate to the intensity of the audio.
+
+ Args:
+ prompts (Optional[List[str]], optional):
+ list of text prompts. Defaults to None.
+ seeds (Optional[List[int]], optional):
+ list of random seeds corresponding to prompts. Defaults to None.
+ num_interpolation_steps (Union[int, List[int]], *optional*):
+ How many interpolation steps between each prompt. Defaults to 5.
+ output_dir (Optional[str], optional):
+ Where to save the video. Defaults to './dreams'.
+ name (Optional[str], optional):
+ Name of the subdirectory of output_dir. Defaults to None.
+ image_file_ext (Optional[str], *optional*, defaults to '.png'):
+ The extension to use when writing video frames.
+ fps (Optional[int], *optional*, defaults to 30):
+ The frames per second in the resulting output videos.
+ num_inference_steps (Optional[int], *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (Optional[float], *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (Optional[float], *optional*, defaults to 0.0):
+ Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ height (Optional[int], *optional*, defaults to None):
+ height of the images to generate.
+ width (Optional[int], *optional*, defaults to None):
+ width of the images to generate.
+ upsample (Optional[bool], *optional*, defaults to False):
+ When True, upsamples images with realesrgan.
+ batch_size (Optional[int], *optional*, defaults to 1):
+ Number of images to generate at once.
+ resume (Optional[bool], *optional*, defaults to False):
+ When True, resumes from the last frame in the output directory based
+ on available prompt config. Requires you to provide the `name` argument.
+ audio_filepath (str, *optional*, defaults to None):
+ Optional path to an audio file to influence the interpolation rate.
+ audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
+ Global start time of the provided audio_filepath.
+ margin (Optional[float], *optional*, defaults to 1.0):
+ Margin from librosa hpss to use for audio interpolation.
+ smooth (Optional[float], *optional*, defaults to 0.0):
+ Smoothness of the audio interpolation. 1.0 means linear interpolation.
+ negative_prompt (Optional[str], *optional*, defaults to None):
+ Optional negative prompt to use. Same across all prompts.
+
+ This function will create sub directories for each prompt and seed pair.
+
+ For example, if you provide the following prompts and seeds:
+
+ ```
+ prompts = ['a dog', 'a cat', 'a bird']
+ seeds = [1, 2, 3]
+ num_interpolation_steps = 5
+ output_dir = 'output_dir'
+ name = 'name'
+ fps = 5
+ ```
+
+ Then the following directories will be created:
+
+ ```
+ output_dir
+ ├── name
+ │   ├── name_000000
+ │   │   ├── frame000000.png
+ │   │   ├── ...
+ │   │   ├── frame000004.png
+ │   │   └── name_000000.mp4
+ │   ├── name_000001
+ │   │   ├── frame000000.png
+ │   │   ├── ...
+ │   │   ├── frame000004.png
+ │   │   └── name_000001.mp4
+ │   ├── ...
+ │   ├── name.mp4
+ │   └── prompt_config.json
+ ```
+
+ Returns:
+ str: The resulting video filepath. This video includes all sub directories' video clips.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ output_path = Path(output_dir)
+
+ name = name or time.strftime("%Y%m%d-%H%M%S")
+ save_path_root = output_path / name
+ save_path_root.mkdir(parents=True, exist_ok=True)
+
+ # Where the final video of all the clips combined will be saved
+ output_filepath = save_path_root / f"{name}.mp4"
+
+ # If using the same number of interpolation steps between each prompt, turn it into a list
+ if not resume and isinstance(num_interpolation_steps, int):
+ num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1)
+
+ if not resume:
+ audio_start_sec = audio_start_sec or 0
+
+ # Save/reload prompt config
+ prompt_config_path = save_path_root / "prompt_config.json"
+ if not resume:
+ prompt_config_path.write_text(
+ json.dumps(
+ dict(
+ prompts=prompts,
+ seeds=seeds,
+ num_interpolation_steps=num_interpolation_steps,
+ fps=fps,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ upsample=upsample,
+ height=height,
+ width=width,
+ audio_filepath=audio_filepath,
+ audio_start_sec=audio_start_sec,
+ negative_prompt=negative_prompt,
+ ),
+ indent=2,
+ sort_keys=False,
+ )
+ )
+ else:
+ data = json.load(open(prompt_config_path))
+ prompts = data["prompts"]
+ seeds = data["seeds"]
+ num_interpolation_steps = data["num_interpolation_steps"]
+ fps = data["fps"]
+ num_inference_steps = data["num_inference_steps"]
+ guidance_scale = data["guidance_scale"]
+ eta = data["eta"]
+ upsample = data["upsample"]
+ height = data["height"]
+ width = data["width"]
+ audio_filepath = data["audio_filepath"]
+ audio_start_sec = data["audio_start_sec"]
+ negative_prompt = data.get("negative_prompt", None)
+
+ for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate(
+ zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps)
+ ):
+ # {name}_000000 / {name}_000001 / ...
+ save_path = save_path_root / f"{name}_{i:06d}"
+
+ # Where the individual clips will be saved
+ step_output_filepath = save_path / f"{name}_{i:06d}.mp4"
+
+ # Determine if we need to resume from a previous run
+ skip = 0
+ if resume:
+ if step_output_filepath.exists():
+ print(f"Skipping {save_path} because frames already exist")
+ continue
+
+ existing_frames = sorted(save_path.glob(f"*{image_file_ext}"))
+ if existing_frames:
+ skip = int(existing_frames[-1].stem[-6:]) + 1
+ if skip + 1 >= num_step:
+ print(f"Skipping {save_path} because frames already exist")
+ continue
+ print(f"Resuming {save_path.name} from frame {skip}")
+
+ audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
+ audio_duration = num_step / fps
+
+ self.make_clip_frames(
+ params,
+ prompt_a,
+ prompt_b,
+ seed_a,
+ seed_b,
+ num_interpolation_steps=num_step,
+ save_path=save_path,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ height=height,
+ width=width,
+ upsample=upsample,
+ batch_size=batch_size,
+ T=get_timesteps_arr(
+ audio_filepath,
+ offset=audio_offset,
+ duration=audio_duration,
+ fps=fps,
+ margin=margin,
+ smooth=smooth,
+ )
+ if audio_filepath
+ else None,
+ skip=skip,
+ negative_prompt=negative_prompt,
+ jit=jit,
+ )
+ make_video_pyav(
+ save_path,
+ audio_filepath=audio_filepath,
+ fps=fps,
+ output_filepath=step_output_filepath,
+ glob_pattern=f"*{image_file_ext}",
+ audio_offset=audio_offset,
+ audio_duration=audio_duration,
+ sr=44100,
+ )
+
+ return make_video_pyav(
+ save_path_root,
+ audio_filepath=audio_filepath,
+ fps=fps,
+ audio_offset=audio_start_sec,
+ audio_duration=sum(num_interpolation_steps) / fps,
+ output_filepath=output_filepath,
+ glob_pattern=f"**/*{image_file_ext}",
+ sr=44100,
+ )
+
+ def embed_text(
+ self, params: Union[Dict, FrozenDict], text: str, negative_prompt=None
+ ):
+ """Helper to embed some text"""
+ prompt_ids = self.prepare_inputs(text)
+ embed = self.text_encoder(prompt_ids, params=params)[0]
+ return embed
+
+ def init_noise(self, prng_seed, noise_shape, dtype):
+ """Helper to initialize noise"""
+ noise = jax.random.normal(prng_seed, shape=noise_shape, dtype=dtype)
+ return noise
+
+ # TODO: port this behavior to flax
+ # @classmethod
+ # def from_pretrained(cls, *args, tiled=False, **kwargs):
+ # """Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
+ # if tiled:
+
+ # def patch_conv(**patch):
+ # cls = nn.Conv2d
+ # init = cls.__init__
+
+ # def __init__(self, *args, **kwargs):
+ # return init(self, *args, **kwargs, **patch)
+
+ # cls.__init__ = __init__
+
+ # patch_conv(padding_mode="circular")
+
+ # pipeline = super().from_pretrained(*args, **kwargs)
+ # pipeline.tiled = tiled
+ # return pipeline
+
+
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+# guidance_scale is a scalar, so it's broadcasted to all devices (hence `None`) without needing to be static.
+@partial(
+ jax.pmap,
+ in_axes=(None, 0, 0, 0, None, None, None, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 4, 5, 6),
+)
+def _p_generate(
+ pipe,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ text_embeddings,
+):
+ return pipe._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ text_embeddings,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
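The padding, sharding, and unsharding choreography above (pad the batch to a multiple of the device count, `shard` it, run the pmapped function, then merge the device axis back and drop the padding) is the core of how `make_clip_frames` uses all TPU cores when `jit=True`. A small, self-contained sketch of that pattern, independent of the pipeline and using only illustrative shapes and a toy pmapped function:

```python
from functools import partial

import jax
import numpy as np
from flax.training.common_utils import shard

num_devices = jax.device_count()

@partial(jax.pmap, in_axes=(0, None))  # batch is split across devices, scale is broadcast
def scale_batch(x, scale):
    return x * scale

batch = np.ones((5, 4), dtype=np.float32)    # 5 items; may not divide num_devices evenly
pad_size = (-batch.shape[0]) % num_devices   # rows of zero-padding needed
padded = np.pad(batch, ((0, pad_size), (0, 0)))
sharded = shard(padded)                      # (num_devices, per_device, 4)

out = scale_batch(sharded, 2.0)              # one call per device, in parallel
out = out.reshape(-1, *out.shape[2:])        # unshard: merge the device and batch axes
out = out[: batch.shape[0]]                  # drop the padded rows, as make_clip_frames does
```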