[WIP] scheduler scaffolding
williamberman committed Oct 11, 2022
1 parent 1b1ee17 commit 496524c
Showing 4 changed files with 267 additions and 1 deletion.
9 changes: 8 additions & 1 deletion scripts/convert_vq_diffusion_to_diffusers.py
@@ -35,6 +35,7 @@
from diffusers import VQModel
from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
from diffusers.pipelines import VQDiffusionPipeline
from diffusers.schedulers import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

@@ -846,10 +847,16 @@ def read_config_file(filename):

# done text encoder

scheduler_model = VQDiffusionScheduler()

print(f"saving VQ diffusion model, path: {args.dump_path}")

pipe = VQDiffusionPipeline(
vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
vqvae=vqvae_model,
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
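# Once saved, the assembled pipeline can be reloaded with the usual diffusers API
# (illustrative sketch, not part of this script; assumes the scheduler config
# round-trips like the other registered components):
#
#   pipe = VQDiffusionPipeline.from_pretrained(args.dump_path)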

129 changes: 129 additions & 0 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -1,7 +1,32 @@
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import torch

import PIL
from diffusers import VQDiffusionTransformer, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from ...pipeline_utils import DiffusionPipeline
from ...utils import BaseOutput, logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@dataclass
class VQDiffusionPipelineOutput(BaseOutput):
"""
Output class for VQ Diffusion pipelines.

Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size`, or a numpy array of shape `(batch_size, height, width,
num_channels)`, representing the denoised images of the diffusion pipeline.
"""

images: Union[List[PIL.Image.Image], np.ndarray]


# This class is a placeholder and does not have the full VQ-diffusion pipeline built out yet
@@ -21,11 +46,115 @@ def __init__(
transformer: VQDiffusionTransformer,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scheduler: VQDiffusionScheduler,
):
super().__init__()

self.register_modules(
vqvae=vqvae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
)

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: int = 256,
width: int = 256,
num_inference_steps: int = 100,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# See where CLIP normalizes the pooled output:
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
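# Shape note (sizes here are illustrative assumptions, not taken from this commit): the
# CLIP last hidden state is (batch, seq_len, hidden), e.g. (batch, 77, 512), so the norm
# above is a per-token L2 norm:
#
#   example = torch.randn(2, 77, 512)
#   example = example / example.norm(dim=-1, keepdim=True)  # every token now has unit norm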

# get the initial random noise unless the user supplied it

# TODO HERE - what's the input shape?
latents_shape = TODO # (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
# all masked?
latents = TODO # torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(self.device)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
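# Sketch of one plausible answer to the TODO above (the names `num_latent_pixels` and
# `mask_class` are assumptions, not defined anywhere in this commit): the commented-out
# shape is the continuous Stable Diffusion latent shape left as a reference, but
# VQ-Diffusion diffuses over discrete VQ-VAE code indices, and sampling conventionally
# starts from the fully masked sequence rather than from Gaussian noise, e.g.
#
#   latents_shape = (batch_size * num_images_per_prompt, num_latent_pixels)
#   latents = torch.full(latents_shape, mask_class, device=self.device)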

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

timesteps_tensor = self.scheduler.timesteps.to(self.device)

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# predict the un-noised image
x0_pred = TODO # self.transformer(latents, t, encoder_hidden_states=text_embeddings).sample

# compute the previous noisy sample x_t -> x_t-1
latents = TODO # self.scheduler.step(x0_pred, t, latents).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
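# Reading of the loop above (descriptive only; the TODOs are left for the full
# implementation): the transformer is expected to return log-probabilities over the
# codebook classes for the un-noised sample x_0, and the scheduler's `step` turns that
# prediction into a sample of x_{t-1} via the posterior q(x_{t-1} | x_t, x_0).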

image = self.vqvae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return image

return VQDiffusionPipelineOutput(images=image)
1 change: 1 addition & 0 deletions src/diffusers/schedulers/__init__.py
@@ -24,6 +24,7 @@
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
else:
from ..utils.dummy_pt_objects import * # noqa F403

129 changes: 129 additions & 0 deletions src/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import Tuple, Union

import torch
import torch.nn.functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


def log_add_exp(a, b):
maximum = torch.max(a, b)
return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
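# `log_add_exp` computes log(exp(a) + exp(b)) in a numerically stable way by factoring
# out the elementwise maximum; e.g. for a = b = 1000.0 the naive expression overflows to
# inf, while this returns 1000.0 + log(2).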


def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
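# `extract` gathers one coefficient per batch element at timestep t and reshapes the
# result to (b, 1, ..., 1) so it broadcasts against tensors of shape `x_shape`.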


def index_to_log_onehot(x, num_classes):
assert x.max().item() < num_classes, f"Error: {x.max().item()} >= {num_classes}"
x_onehot = F.one_hot(x, num_classes)
permute_order = (0, -1) + tuple(range(1, len(x.size())))
x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
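# `index_to_log_onehot` turns integer class indices into log-one-hot vectors: the permute
# moves the class dimension to position 1, and the clamp replaces log(0) with a finite
# surrogate (log(1e-30), roughly -69) instead of -inf.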


def log_onehot_to_index(log_x):
return log_x.argmax(1)


@dataclass
class VQDiffusionSchedulerOutput(BaseOutput):
...


class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(self):
...
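# The `...` placeholders are part of the scaffolding: the methods below assume __init__
# will eventually define num_timesteps, num_classes (the codebook size plus the [MASK]
# class), content_seq_len, and the log transition buffers (log_at, log_bt, log_ct, their
# cumulative-product counterparts, and the log(1 - .) variants).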

def set_timesteps(self, num_inference_steps):
...

def step(self, out, t, log_x) -> Union[VQDiffusionSchedulerOutput, Tuple]:
log_x_recon = F.log_softmax(out.double(), dim=1).float()
batch_size = TODO

zero_vector = torch.zeros(batch_size, 1, self.content_seq_len) - 70
log_x_recon = torch.cat((log_x_recon, zero_vector), dim=1)
log_x_recon = torch.clamp(log_x_recon, -70, 0)

log_model_pred = self.q_posterior(log_x_start=log_x_recon, log_x_t=log_x, t=t)

out = self.log_sample_categorical(log_model_pred)

return out
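# Reading of `step` (descriptive only; the batch_size TODO is left to the full
# implementation): the model output is converted to log-probabilities over the non-mask
# classes, a column of roughly log(0) is appended for the mask class, the posterior
# q(x_{t-1} | x_t, x_0) is formed, and x_{t-1} is sampled from it with the Gumbel trick
# implemented in `log_sample_categorical` below.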

def q_posterior(self, log_x_start, log_x_t, t): # p_theta(xt_1|xt) = sum(q(xt-1|xt,x0')*p(x0'))
# notice that log_x_t is onehot
assert t.min().item() >= 0 and t.max().item() < self.num_timesteps
batch_size = log_x_start.size()[0]
onehot_x_t = log_onehot_to_index(log_x_t)
mask = (onehot_x_t == self.num_classes - 1).unsqueeze(1)
log_one_vector = torch.zeros(batch_size, 1, 1).type_as(log_x_t)
log_zero_vector = torch.log(log_one_vector + 1.0e-30).expand(-1, -1, self.content_seq_len)

log_qt = self.q_pred(log_x_t, t) # q(xt|x0)
log_qt = log_qt[:, :-1, :]
log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes - 1, -1)
log_qt = (~mask) * log_qt + mask * ct_cumprod_vector

log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t) # q(xt|xt_1)
log_qt_one_timestep = torch.cat((log_qt_one_timestep[:, :-1, :], log_zero_vector), dim=1)
log_ct = extract(self.log_ct, t, log_x_start.shape) # ct
ct_vector = log_ct.expand(-1, self.num_classes - 1, -1)
ct_vector = torch.cat((ct_vector, log_one_vector), dim=1)
log_qt_one_timestep = (~mask) * log_qt_one_timestep + mask * ct_vector

q = log_x_start[:, :-1, :] - log_qt
q = torch.cat((q, log_zero_vector), dim=1)
q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
q = q - q_log_sum_exp
log_EV_xtmin_given_xt_given_xstart = self.q_pred(q, t - 1) + log_qt_one_timestep + q_log_sum_exp
return torch.clamp(log_EV_xtmin_given_xt_given_xstart, -70, 0)
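# `q_posterior` evaluates log p_theta(x_{t-1} | x_t) by combining the predicted x_0
# distribution with the forward-process terms: q(x_t | x_0) and q(x_t | x_{t-1}) come
# from the cumulative and per-step transition rates, and the `mask` branch swaps in the
# ct terms for positions whose current token is the absorbing [MASK] class.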

def q_pred(self, log_x_start, t): # q(xt|x0)
# log_x_start can be onehot or not
t = (t + (self.num_timesteps + 1)) % (self.num_timesteps + 1)
log_cumprod_at = extract(self.log_cumprod_at, t, log_x_start.shape) # at~
log_cumprod_bt = extract(self.log_cumprod_bt, t, log_x_start.shape) # bt~
log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
log_1_min_cumprod_ct = extract(self.log_1_min_cumprod_ct, t, log_x_start.shape) # 1-ct~

log_probs = torch.cat(
[
log_add_exp(log_x_start[:, :-1, :] + log_cumprod_at, log_cumprod_bt),
log_add_exp(log_x_start[:, -1:, :] + log_1_min_cumprod_ct, log_cumprod_ct),
],
dim=1,
)

return log_probs
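# `q_pred` evaluates log q(x_t | x_0) in closed form: non-mask classes get probability
# cumprod(a_t) * onehot(x_0) + cumprod(b_t) (keep the original token or resample it),
# while the [MASK] class accumulates the cumulative c_t mass and acts as an absorbing
# state, which is why the second log_add_exp term only covers the last class row.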

def q_pred_one_timestep(self, log_x_t, t): # q(xt|xt_1)
log_at = extract(self.log_at, t, log_x_t.shape) # at
log_bt = extract(self.log_bt, t, log_x_t.shape) # bt
log_ct = extract(self.log_ct, t, log_x_t.shape) # ct
log_1_min_ct = extract(self.log_1_min_ct, t, log_x_t.shape) # 1-ct

log_probs = torch.cat(
[log_add_exp(log_x_t[:, :-1, :] + log_at, log_bt), log_add_exp(log_x_t[:, -1:, :] + log_1_min_ct, log_ct)],
dim=1,
)

return log_probs

# use gumbel to sample onehot vector from log probability
def log_sample_categorical(self, logits):
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = (gumbel_noise + logits).argmax(dim=1)
log_sample = index_to_log_onehot(sample, self.num_classes)
return log_sample
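# The Gumbel-max trick above draws exact samples from Categorical(softmax(logits)) along
# dim=1. A quick self-contained check (illustrative only):
#
#   logits = torch.log(torch.tensor([0.7, 0.2, 0.1])).reshape(1, 3, 1)
#   counts = torch.zeros(3)
#   for _ in range(10_000):
#       g = -torch.log(-torch.log(torch.rand_like(logits) + 1e-30) + 1e-30)
#       counts[(g + logits).argmax(dim=1)] += 1
#   # counts / 10_000 is approximately (0.7, 0.2, 0.1)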
