[WIP] scheduler
williamberman committed Oct 20, 2022
1 parent 1b1ee17 commit 9d33a33
Showing 8 changed files with 687 additions and 61 deletions.
155 changes: 155 additions & 0 deletions orig_scheduler.py
@@ -0,0 +1,155 @@
import torch
import torch.nn.functional as F  # F.one_hot is used in index_to_log_onehot below
import numpy as np


def log_1_min_a(a):
    return torch.log(1 - a.exp() + 1e-40)

def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)

def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, \
        f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x

def log_onehot_to_index(log_x):
    return log_x.argmax(1)

def alpha_schedule(time_step, N=100, att_1 = 0.99999, att_T = 0.000009, ctt_1 = 0.000009, ctt_T = 0.99999):
    att = np.arange(0, time_step)/(time_step-1)*(att_T - att_1) + att_1
    att = np.concatenate(([1], att))
    at = att[1:]/att[:-1]
    ctt = np.arange(0, time_step)/(time_step-1)*(ctt_T - ctt_1) + ctt_1
    ctt = np.concatenate(([0], ctt))
    one_minus_ctt = 1 - ctt
    one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
    ct = 1-one_minus_ct
    bt = (1-at-ct)/N
    att = np.concatenate((att[1:], [1]))
    ctt = np.concatenate((ctt[1:], [0]))
    btt = (1-att-ctt)/N
    return at, bt, ct, att, btt, ctt


class OrigScheduler:
    def __init__(self, *, num_classes, content_seq_len, num_timesteps=100):
        self.num_timesteps = num_timesteps
        self.num_classes = num_classes
        self.content_seq_len = content_seq_len

        at, bt, ct, att, btt, ctt = alpha_schedule(self.num_timesteps, N=self.num_classes-1)

        at = torch.tensor(at.astype('float64'))
        bt = torch.tensor(bt.astype('float64'))
        ct = torch.tensor(ct.astype('float64'))
        log_at = torch.log(at)
        log_bt = torch.log(bt)
        log_ct = torch.log(ct)
        att = torch.tensor(att.astype('float64'))
        btt = torch.tensor(btt.astype('float64'))
        ctt = torch.tensor(ctt.astype('float64'))
        log_cumprod_at = torch.log(att)
        log_cumprod_bt = torch.log(btt)
        log_cumprod_ct = torch.log(ctt)

        log_1_min_ct = log_1_min_a(log_ct)
        log_1_min_cumprod_ct = log_1_min_a(log_cumprod_ct)

        assert log_add_exp(log_ct, log_1_min_ct).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_ct, log_1_min_cumprod_ct).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.log_at = log_at.float()
        self.log_bt = log_bt.float()
        self.log_ct = log_ct.float()
        self.log_cumprod_at = log_cumprod_at.float()
        self.log_cumprod_bt = log_cumprod_bt.float()
        self.log_cumprod_ct = log_cumprod_ct.float()
        self.log_1_min_ct = log_1_min_ct.float()
        self.log_1_min_cumprod_ct = log_1_min_cumprod_ct.float()



    def q_posterior(self, log_x_start, log_x_t, t): # p_theta(xt_1|xt) = sum(q(xt-1|xt,x0')*p(x0'))
        # notice that log_x_t is onehot
        assert t.min().item() >= 0 and t.max().item() < self.num_timesteps
        batch_size = log_x_start.size()[0]
        onehot_x_t = log_onehot_to_index(log_x_t)
        mask = (onehot_x_t == self.num_classes-1).unsqueeze(1)
        log_one_vector = torch.zeros(batch_size, 1, 1).type_as(log_x_t)
        log_zero_vector = torch.log(log_one_vector+1.0e-30).expand(-1, -1, self.content_seq_len)

        log_qt = self.q_pred(log_x_t, t) # q(xt|x0)
        # log_qt = torch.cat((log_qt[:,:-1,:], log_zero_vector), dim=1)
        log_qt = log_qt[:,:-1,:]
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
        ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes-1, -1)
        # ct_cumprod_vector = torch.cat((ct_cumprod_vector, log_one_vector), dim=1)
        log_qt = (~mask)*log_qt + mask*ct_cumprod_vector


        log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t) # q(xt|xt_1)
        log_qt_one_timestep = torch.cat((log_qt_one_timestep[:,:-1,:], log_zero_vector), dim=1)
        log_ct = extract(self.log_ct, t, log_x_start.shape) # ct
        ct_vector = log_ct.expand(-1, self.num_classes-1, -1)
        ct_vector = torch.cat((ct_vector, log_one_vector), dim=1)
        log_qt_one_timestep = (~mask)*log_qt_one_timestep + mask*ct_vector

        # log_x_start = torch.cat((log_x_start, log_zero_vector), dim=1)
        # q = log_x_start - log_qt
        q = log_x_start[:,:-1,:] - log_qt
        q = torch.cat((q, log_zero_vector), dim=1)
        q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
        q = q - q_log_sum_exp
        log_EV_xtmin_given_xt_given_xstart = self.q_pred(q, t-1) + log_qt_one_timestep + q_log_sum_exp
        return torch.clamp(log_EV_xtmin_given_xt_given_xstart, -70, 0)


    def q_pred_one_timestep(self, log_x_t, t): # q(xt|xt_1)
        log_at = extract(self.log_at, t, log_x_t.shape) # at
        log_bt = extract(self.log_bt, t, log_x_t.shape) # bt
        log_ct = extract(self.log_ct, t, log_x_t.shape) # ct
        log_1_min_ct = extract(self.log_1_min_ct, t, log_x_t.shape) # 1-ct

        log_probs = torch.cat(
            [
                log_add_exp(log_x_t[:,:-1,:]+log_at, log_bt),
                log_add_exp(log_x_t[:, -1:, :] + log_1_min_ct, log_ct)
            ],
            dim=1
        )

        return log_probs

    def q_pred(self, log_x_start, t): # q(xt|x0)
        # log_x_start can be onehot or not
        t = (t + (self.num_timesteps + 1))%(self.num_timesteps + 1)
        log_cumprod_at = extract(self.log_cumprod_at, t, log_x_start.shape) # at~
        log_cumprod_bt = extract(self.log_cumprod_bt, t, log_x_start.shape) # bt~
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
        log_1_min_cumprod_ct = extract(self.log_1_min_cumprod_ct, t, log_x_start.shape) # 1-ct~


        log_probs = torch.cat(
            [
                log_add_exp(log_x_start[:,:-1,:]+log_cumprod_at, log_cumprod_bt),
                log_add_exp(log_x_start[:,-1:,:]+log_1_min_cumprod_ct, log_cumprod_ct)
            ],
            dim=1
        )

        return log_probs
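For reference, at, bt, and ct play the roles of the per-step transition rates in the VQ-Diffusion formulation: at puts extra mass on keeping the current token, bt spreads mass uniformly over the N non-mask classes, and ct moves mass to the absorbing [MASK] class; att, btt, and ctt are their cumulative counterparts, so each column of the (cumulative) transition matrix sums to one. q_pred then evaluates q(x_t | x_0) in log space as att * 1[x_0 = k] + btt for a non-mask class k and ctt + (1 - ctt) * 1[x_0 = MASK] for the mask class. A small sanity-check sketch, assuming orig_scheduler.py is importable and using sizes chosen only for illustration:

import numpy as np
import torch
from orig_scheduler import OrigScheduler, alpha_schedule, index_to_log_onehot

N = 1024  # hypothetical VQ-VAE codebook size (non-mask classes)

# per-step and cumulative rates each normalize: keep + N * uniform + mask == 1
at, bt, ct, att, btt, ctt = alpha_schedule(100, N=N)
assert np.allclose(at + N * bt + ct, 1.0)
assert np.allclose(att + N * btt + ctt, 1.0)

# q(x_t | x_0) produced by q_pred is a valid distribution over the N + 1 classes
sched = OrigScheduler(num_classes=N + 1, content_seq_len=16 * 16, num_timesteps=100)
x0 = torch.randint(0, N, (2, 16 * 16))        # content tokens only, no [MASK] in x_0
log_x0 = index_to_log_onehot(x0, N + 1)       # B x (N + 1) x L log one-hot
t = torch.tensor([0, 99])
log_qt = sched.q_pred(log_x0, t)
assert torch.allclose(log_qt.exp().sum(dim=1), torch.ones(2, 16 * 16), atol=1e-4)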
23 changes: 20 additions & 3 deletions scripts/convert_vq_diffusion_to_diffusers.py
@@ -32,9 +32,8 @@
 
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQModel
+from diffusers import VQModel, VQDiffusionPipeline, VQDiffusionScheduler
 from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
-from diffusers.pipelines import VQDiffusionPipeline
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
@@ -492,7 +491,12 @@ def transformer_model_from_original_config(
 
     depth = original_transformer_config["n_layer"]
     context_dim = original_transformer_config["condition_dim"]
+
     num_embed = original_content_embedding_config["num_embed"]
+    # the number of embeddings in the transformer includes the mask embedding.
+    # the content embedding (the vqvae) does not include the mask embedding.
+    num_embed = num_embed + 1
+
     height = original_transformer_config["content_spatial_size"][0]
     width = original_transformer_config["content_spatial_size"][1]
     dropout = original_transformer_config["resid_pdrop"]
@@ -846,10 +850,23 @@ def read_config_file(filename):
 
     # done text encoder
 
+    # scheduler
+
+    scheduler_model = VQDiffusionScheduler(
+        # the scheduler has the same number of embeddings as the transformer
+        num_embed=transformer_model.num_embed
+    )
+
+    # done scheduler
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
-        vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
+        vqvae=vqvae_model,
+        transformer=transformer_model,
+        tokenizer=tokenizer_model,
+        text_encoder=text_encoder_model,
+        scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
 
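For orientation, a minimal sketch of reloading what the script saves; the dump path is hypothetical, and since this is a WIP commit, end-to-end loading and generation are not implied to work yet:

from diffusers import VQDiffusionPipeline

# reload the checkpoint that save_pretrained wrote above (path is hypothetical)
pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-converted")

# the transformer and the scheduler were both built with the same num_embed,
# i.e. the VQ-VAE codebook size plus one extra class for the [MASK] token
print(pipe.transformer.num_embed)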
3 changes: 2 additions & 1 deletion src/diffusers/__init__.py
@@ -29,14 +29,15 @@
         get_scheduler,
     )
     from .pipeline_utils import DiffusionPipeline
-    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline, VQDiffusionPipeline
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        VQDiffusionScheduler
     )
     from .training_utils import EMAModel
 else:
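With these exports in place, both new classes should be importable from the package root (a sketch assuming this branch of diffusers is installed):

from diffusers import VQDiffusionPipeline, VQDiffusionScheduler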
49 changes: 0 additions & 49 deletions src/diffusers/models/embeddings.py
@@ -115,52 +115,3 @@ def forward(self, x):
         x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
         out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
         return out
-
-
-# TODO(will) - document this. check if throwing errors internally is appropriate
-class DalleMaskImageEmbedding(nn.Module):
-    def __init__(
-        self,
-        num_embed,
-        height,
-        width,
-        embed_dim,
-    ):
-        super().__init__()
-
-        self.height = height
-        self.width = width
-        # TODO(will) add docs on why this is incremented by 1. (Has to do with mask?)
-        self.num_embed = num_embed + 1
-        self.embed_dim = embed_dim
-
-        self.emb = nn.Embedding(self.num_embed, embed_dim)
-        self.height_emb = nn.Embedding(self.height, embed_dim)
-        self.width_emb = nn.Embedding(self.width, embed_dim)
-
-    def forward(self, index):
-        assert index.dim() == 2  # B x L
-        try:
-            index[index < 0] = 0
-            emb = self.emb(index)
-        except:
-            raise RuntimeError(
-                "IndexError: index out of range in self, max index {}, num embed {}".format(
-                    index.max(), self.num_embed
-                )
-            )
-
-        # add col and row embedding
-        if emb.shape[1] > 0:
-            height_emb = self.height_emb(
-                torch.arange(self.height, device=index.device).view(1, self.height)
-            ).unsqueeze(
-                2
-            )  # 1 x H x D -> 1 x H x 1 x D
-            width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
-                1
-            )  # 1 x W x D -> 1 x 1 x W x D
-            pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1)  # 1 x H x W x D -> 1 x L x D
-            emb = emb + pos_emb[:, : emb.shape[1], :]
-
-        return emb
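The removed DalleMaskImageEmbedding builds its 2-D positional embedding by broadcasting separate row and column embeddings and flattening them to the sequence length. A self-contained sketch of that shape bookkeeping, with H, W, and D made up for illustration:

import torch
import torch.nn as nn

H, W, D = 4, 3, 8
height_emb = nn.Embedding(H, D)(torch.arange(H)).view(1, H, 1, D)  # 1 x H x 1 x D
width_emb = nn.Embedding(W, D)(torch.arange(W)).view(1, 1, W, D)   # 1 x 1 x W x D
pos_emb = (height_emb + width_emb).view(1, H * W, D)               # broadcast to 1 x H x W x D, then 1 x L x D
assert pos_emb.shape == (1, H * W, D)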