From 10db192cc4be66b3cebbdaa48a1806807578b56f Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 9 Sep 2022 09:26:10 -0400
Subject: [PATCH] changes to dogettx optimizations to run on m1

* Author @any-winter-4079
* Author @dogettx

Thanks to many individuals who contributed time and hardware to
benchmarking and debugging these changes.
---
 ldm/generate.py                       |  27 +-
 ldm/modules/attention.py              |  94 ++-
 ldm/modules/diffusionmodules/model.py | 947 +++++++++++---------
 3 files changed, 488 insertions(+), 580 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index ce54b04ebab..3a81087d5f5 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -35,17 +35,7 @@
 from ldm.generate import Generate

 # Create an object with default values
-gr = Generate(model = // models/ldm/stable-diffusion-v1/model.ckpt
-              config = // configs/stable-diffusion/v1-inference.yaml
-              iterations = // how many times to run the sampling (1)
-              steps = // 50
-              seed = // current system time
-              sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
-              grid = // false
-              width = // image width, multiple of 64 (512)
-              height = // image height, multiple of 64 (512)
-              cfg_scale = // condition-free guidance scale (7.5)
-              )
+gr = Generate()

 # do the slow model initialization
 gr.load_model()
@@ -86,6 +76,21 @@
 Note that the old txt2img() and img2img() calls are deprecated but will
 still work.
+
+The full list of arguments to Generate() is:
+gr = Generate(
+          weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
+          config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
+          iterations = // how many times to run the sampling (1)
+          steps = // 50
+          seed = // current system time
+          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
+          grid = // false
+          width = // image width, multiple of 64 (512)
+          height = // image height, multiple of 64 (512)
+          cfg_scale = // classifier-free guidance scale (7.5)
+          )
+
 """

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index c2f905688c8..0dd957b4074 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,20 +1,20 @@
-import math
 from inspect import isfunction
-
+import math
 import torch
 import torch.nn.functional as F
-from einops import rearrange, repeat
 from torch import nn, einsum
+from einops import rearrange, repeat

 from ldm.modules.diffusionmodules.util import checkpoint
+import psutil


 def exists(val):
     return val is not None


 def uniq(arr):
-    return {el: True for el in arr}.keys()
+    return{el: True for el in arr}.keys()


 def default(val, d):
@@ -83,14 +83,14 @@ def __init__(self, dim, heads=4, dim_head=32):
         super().__init__()
         self.heads = heads
         hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)

     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -132,12 +132,12 @@ def forward(self, x):
         v = self.v(h_)

         # compute attention
-        
b, c, h, w = q.shape + b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) - w_ = w_ * (int(c) ** (-0.5)) + w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values @@ -147,7 +147,7 @@ def forward(self, x): h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x + h_ + return x+h_ class CrossAttention(nn.Module): @@ -171,41 +171,66 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - device_type = x.device.type + k_in = self.to_k(context) + v_in = self.to_v(context) + device_type = 'mps' if x.device.type == 'mps' else 'cuda' del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40) - del q, k + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - del mask + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if device_type == 'mps': #special case for M1 - disable neonsecret optimization - sim = sim.softmax(dim=-1) + if device_type == 'mps': + mem_free_total = psutil.virtual_memory().available else: - sim[4:] = sim[4:].softmax(dim=-1) - sim[:4] = sim[:4].softmax(dim=-1) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - sim = einsum('b i j, b j d -> b i d', sim, v) - sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) - return self.to_out(sim) + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, - dropout=dropout) # is a self-attention + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none @@ -233,7 +258,6 @@ class SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None): super().__init__() @@ -249,7 +273,7 @@ def __init__(self, in_channels, n_heads, d_head, self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] + for d in range(depth)] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index cd79e375657..a6cefc82ada 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -8,6 +9,7 @@ from ldm.util import instantiate_from_config from ldm.modules.attention import LinearAttention +import psutil def get_timestep_embedding(timesteps, embedding_dim): """ @@ -26,19 +28,17 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb def nonlinearity(x): # swish - return x * torch.sigmoid(x) + return x*torch.sigmoid(x) def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): @@ -46,14 +46,14 @@ def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): - x = torch.nn.functional.interpolate( - x, scale_factor=2.0, mode='nearest' - ) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x @@ -65,14 +65,16 @@ def __init__(self, in_channels, with_conv): self.with_conv = with_conv if self.with_conv: # no asymmetric padding 
in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) def forward(self, x): if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -80,15 +82,8 @@ def forward(self, x): class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -96,47 +91,60 @@ def __init__( self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - ) + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - ) + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -144,12 +152,10 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x + h - + return x + h8 class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" - def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) @@ -160,87 +166,116 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - 
self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm( - v, w_ - ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -def make_attn(in_channels, attn_type='vanilla'): - assert attn_type in [ - 'vanilla', - 'linear', - 'none', - ], f'attn_type {attn_type} unknown' - print( - f"making attention of type '{attn_type}' with {in_channels} in_channels" - ) - if attn_type == 'vanilla': + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + device_type = 'mps' if q.device.type == 'mps' else 'cuda' + + if device_type == 'mps': + mem_free_total = psutil.virtual_memory().available + else: + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": return AttnBlock(in_channels) - elif attn_type == 'none': + elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - 
dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type='vanilla', - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch - self.temb_ch = self.ch * 4 + self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -250,80 +285,70 @@ def __init__( if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1,)+tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + 
out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -333,16 +358,18 @@ def __init__( if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution + #assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -364,7 +391,7 @@ def forward(self, x, t=None, context=None): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -375,10 +402,9 @@ def forward(self, x, t=None, context=None): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) + torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -395,27 +421,12 @@ def get_last_layer(self): class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type='vanilla', - **ignore_kwargs, - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -424,64 +435,56 @@ def __init__( self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = 
attn - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): # timestep embedding @@ -495,7 +498,7 @@ def forward(self, x): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -512,28 +515,12 @@ def forward(self, x): class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type='vanilla', - **ignorekwargs, - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -544,52 +531,43 @@ def __init__( self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print( - 'Working with z of shape {} = {} dimensions.'.format( - self.z_shape, np.prod(self.z_shape) - ) - ) + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - 
self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -599,87 +577,103 @@ def __init__( if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] + #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + device_type = 'mps' if h.device.type == 'mps' else 'cuda' + gc.collect() + if device_type == 'cuda': + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h class SimpleDecoder(nn.Module): def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + 
temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1, 2, 3]: + if i in [1,2,3]: x = layer(x, None) else: x = layer(x) @@ -691,16 +685,8 @@ def forward(self, x): class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 @@ -714,14 +700,10 @@ def __init__( res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -730,9 +712,11 @@ def __init__( # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): # upsampling @@ -749,56 +733,35 @@ def forward(self, x): class LatentRescaler(nn.Module): - def __init__( - self, factor, in_channels, mid_channels, out_channels, depth=2 - ): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)), - ), 
- ) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -807,42 +770,17 @@ def forward(self, x): class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) def forward(self, x): x = self.encoder(x) @@ -851,41 +789,15 @@ def forward(self, x): class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) def forward(self, x): x = self.rescaler(x) @@ -894,32 +806,17 @@ def forward(self, x): class Upsampler(nn.Module): - def __init__( - self, in_size, out_size, in_channels, out_channels, ch_mult=2 - ): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % 
in_size) - print( - f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}' - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) def forward(self, x): x = self.rescaler(x) @@ -928,55 +825,42 @@ def forward(self, x): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode='bilinear'): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: - print( - f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode' - ) + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: + if scale_factor==1.0: return x else: - x = torch.nn.functional.interpolate( - x, - mode=self.mode, - align_corners=False, - scale_factor=scale_factor, - ) + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) return x - class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): super().__init__() if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' self.pretrained_model = pretrained_model else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' self.instantiate_pretrained(pretrained_config) self.do_reshape = reshape @@ -984,28 +868,22 @@ def __init__( if n_channels is None: n_channels = self.pretrained_model.encoder.ch - self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) + self.proj_norm = 
Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, - out_channels=m * n_channels, - dropout=dropout, - ) - ) + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) + def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() @@ -1013,23 +891,24 @@ def instantiate_pretrained(self, config): for param in self.pretrained_model.parameters(): param.requires_grad = False + @torch.no_grad() - def encode_with_pretrained(self, x): + def encode_with_pretrained(self,x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() - return c + return c - def forward(self, x): + def forward(self,x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) z = downmodel(z) if self.do_reshape: - z = rearrange(z, 'b c h w -> b (h w) c') + z = rearrange(z,'b c h w -> b (h w) c') return z
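
Notes on the approach (illustrative sketches, not part of the diff):

The free-memory probe used by CrossAttention.forward() and
AttnBlock.forward() treats the two back ends differently. On Apple
silicon the GPU shares system RAM, so the slice budget is psutil's
available-memory figure; on CUDA the budget is the driver's free memory
plus whatever the PyTorch caching allocator has reserved but is not
actively using, since both pools can back new tensors. A minimal
standalone version of that probe (the helper name is ours, not the
patch's):

    import psutil
    import torch

    def free_memory_bytes(device: torch.device) -> int:
        if device.type == 'mps':
            # Unified memory: available system RAM is the GPU budget.
            return psutil.virtual_memory().available
        # CUDA: blocks the caching allocator holds but is not using
        # (reserved - active) are reusable without asking the driver.
        stats = torch.cuda.memory_stats(device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(device)
        return mem_free_cuda + (mem_reserved - mem_active)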
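
The slicing itself: for q of shape (batch*heads, n, d), the full
similarity matrix against k holds batch*heads*n*n float32 values at 4
bytes each. The patch multiplies that by a 2.5 safety factor to cover
softmax temporaries, rounds any overshoot of free memory up to a
power-of-two step count, and runs attention one query slice at a time.
A condensed sketch of the same loop (mask handling omitted; unlike the
patch's r1 buffer, the output here pins dtype to q's, a small defensive
tweak):

    import math
    import torch

    def sliced_attention(q, k, v, scale, mem_free_total):
        # Full attention matrix would be (batch*heads, n, n) float32.
        mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5
        steps = 1
        if mem_required > mem_free_total:
            steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

        out = torch.zeros(q.shape[0], q.shape[1], v.shape[2],
                          device=q.device, dtype=q.dtype)
        # Fall back to one full-size pass if n is not divisible by steps.
        slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            sim = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
            out[:, i:end] = torch.einsum('b i j, b j d -> b i d',
                                         sim.softmax(dim=-1), v)
            del sim  # release this slice's matrix before the next pass
        return out

For scale: at 512x512 the first attention blocks see a 64x64 latent, so
n = 4096; with 8 heads and a batch of 2 (conditioned plus unconditioned
pass), q is (16, 4096, 40) and the matrix alone is 16*4096*4096*4 bytes,
about 1.0 GiB, giving a mem_required of roughly 2.5 GiB. With, say,
1 GiB free, steps rounds up to 4 and the loop covers 1024 query rows per
pass.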
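
Finally, the model.py rewrite pairs two low-tech measures: intermediates
get numbered names (h1, h2, ...) and are del'd as soon as their consumer
has run, a discipline aimed at keeping peak residency down on
memory-tight devices, and Decoder.forward() drops Python garbage plus
the CUDA allocator cache before upsampling, where activations are at
their largest. The cleanup step in isolation (helper name is ours; at
the time of this patch PyTorch exposed no equivalent cache flush for
MPS, hence the CUDA-only branch):

    import gc
    import torch

    def pre_upsample_cleanup(device: torch.device) -> None:
        gc.collect()                  # drop unreferenced Python-side tensors
        if device.type == 'cuda':
            torch.cuda.empty_cache()  # return cached blocks to the driver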