Skip to content

Commit

Permalink
[SDE] Merge to unconditional model (open-mmlab#89)
Browse files Browse the repository at this point in the history
* up

* more

* uP

* make dummy test pass

* save intermediate

* p

* p

* finish

* finish

* finish
  • Loading branch information
patrickvonplaten authored Jul 18, 2022
1 parent b5c684f commit ba3c9a9
Show file tree
Hide file tree
Showing 9 changed files with 1,371 additions and 547 deletions.
15 changes: 14 additions & 1 deletion conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_output_pretrained_ldm():
# 2. DDPM

def get_model(model_id):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model = UNetUnconditionalModel.from_pretrained(model_id, ldm=True)

noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
Expand All @@ -123,3 +123,16 @@ def get_model(model_id):

# e.g.
get_model("fusing/ddpm-cifar10")

# 3. NCSNpp

# Repos to convert and port to google (part of https://github.com/yang-song/score_sde)
# - https://huggingface.co/fusing/ffhq_ncsnpp
# - https://huggingface.co/fusing/church_256-ncsnpp-ve
# - https://huggingface.co/fusing/celebahq_256-ncsnpp-ve
# - https://huggingface.co/fusing/bedroom_256-ncsnpp-ve
# - https://huggingface.co/fusing/ffhq_256-ncsnpp-ve

# tests to make sure to pass
# - test_score_sde_ve_pipeline (in PipelineTesterMixin)
# - test_output_pretrained_ve_mid, test_output_pretrained_ve_large (in NCSNppModelTests)
315 changes: 152 additions & 163 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,166 +6,6 @@
from torch import nn


# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

def forward(self, x, encoder_states=None):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)


# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""

def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels

self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor

if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

self.proj = zero_module(nn.Conv1d(channels, channels, 1))

self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear

if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)

self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)

self.is_overwritten = False

def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias

proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data

self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data

self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data

def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True

b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)

qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)

if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)

h = self.proj(h)
h = h.reshape(b, c, *spatial)

result = x + h

result = result / self.rescale_output_factor

return result


class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
Expand Down Expand Up @@ -216,6 +56,7 @@ def forward(self, hidden_states):

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

# proj to q, k, v
Expand All @@ -229,9 +70,9 @@ def forward(self, hidden_states):
value_states = self.transpose_for_scores(value_proj)

# get scores
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

# compute attention output
context_states = torch.matmul(attention_probs, value_states)
Expand Down Expand Up @@ -263,6 +104,20 @@ def set_weight(self, attn_layer):

self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
elif hasattr(attn_layer, "NIN_0"):
self.query.weight.data = attn_layer.NIN_0.W.data.T
self.key.weight.data = attn_layer.NIN_1.W.data.T
self.value.weight.data = attn_layer.NIN_2.W.data.T

self.query.bias.data = attn_layer.NIN_0.b.data
self.key.bias.data = attn_layer.NIN_1.b.data
self.value.bias.data = attn_layer.NIN_2.b.data

self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
self.proj_attn.bias.data = attn_layer.NIN_3.b.data

self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
else:
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
Expand Down Expand Up @@ -452,3 +307,137 @@ def __init__(self, dim_in, dim_out):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)


# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""

def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels

self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor

if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

self.proj = zero_module(nn.Conv1d(channels, channels, 1))

self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear

if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)

self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)

self.is_overwritten = False

def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias

proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data

self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj.bias.data = self.NIN_3.b.data

self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data

def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True

b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)

qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)

if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)

h = self.proj(h)
h = h.reshape(b, c, *spatial)

result = x + h

result = result / self.rescale_output_factor

return result
12 changes: 9 additions & 3 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ def get_timestep_embedding(
return emb


# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

self.weight = self.W

def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
Loading

0 comments on commit ba3c9a9

Please sign in to comment.