def forward(self, x, split=0, snum=8):
    """Upsample ``x`` 2x with nearest-neighbour, then apply ``self.conv`` if enabled.

    Args:
        x: input tensor. Dim 3 is treated as the width axis that gets tiled
            (presumably a (batch, channels, height, width) image tensor --
            confirm against callers).
        split: when truthy, the convolution is evaluated tile-by-tile along
            the width axis instead of on the whole tensor at once.
        snum: requested number of width tiles. Any positive value is
            accepted; the width does not have to be divisible by it.

    Returns:
        The upsampled tensor, convolved when ``self.with_conv`` is truthy.
        The tiled path is meant to return the same values as the untiled one.
    """
    # Nearest-neighbour upsampling is computed per output pixel, so one call
    # on the whole tensor yields the same values as upsampling width chunks
    # and concatenating them.
    x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

    # Honour with_conv on every path: self.conv is only touched when the
    # module was configured to convolve.
    if not self.with_conv:
        return x
    if not split:
        return self.conv(x)

    # Each tile is convolved together with `halo` extra columns per side and
    # the halo is cropped from the result, so the zero padding of the conv
    # only ever influences columns that are thrown away (or true borders).
    # NOTE(review): exactness assumes self.conv is stride-1, "same"-padded
    # and reaches at most `halo` columns sideways; the constructor context
    # shows stride=1, padding=1 -- kernel size is not visible here, confirm.
    halo = 2
    width = x.shape[3]
    tile = max(1, -(-width // max(1, int(snum))))  # ceil(width / snum), >= 1
    if tile >= width:
        # Zero or one tile: nothing to split.
        return self.conv(x)

    pieces = []
    for start in range(0, width, tile):
        stop = min(start + tile, width)
        lo = max(start - halo, 0)
        hi = min(stop + halo, width)
        out = self.conv(x[:, :, :, lo:hi])
        offset = start - lo  # columns of left halo to drop
        pieces.append(out[:, :, :, offset:offset + (stop - start)])
        del out
    del x
    return torch.cat(pieces, 3)
3) + t1 = [] + for i in t: + t2 = self.conv1(i) + t1.append(t2) + del t2 + del t + temp1 = torch.cat(t1, 3) + del t1 + for i in range(1, snum): + ssize = [int(i * s) - 2, 4, w - int(i * s) - 2] + ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4] + t5 = torch.split(temp1, ssize, 3) + t6 = torch.split(h3, ssize2, 3) + t7 = t6[1] + del t6 + t8 = self.conv1(t7) + del t7 + t9 = torch.split(t8, [2, 4, 2], 3) + del t8 + temp1 = torch.cat((t5[0], t9[1], t5[2]), 3) + del t5 + del t9 + h4 = temp1 + del temp1 + else: + h4 = self.conv1(h3) del h3 if temb is not None: @@ -146,7 +221,40 @@ def forward(self, x, temb): h7 = self.dropout(h6) del h6 - h8 = self.conv2(h7) + l = len(h7[0][0][0]) + l2 = len(h7[0][0]) + # same as above + if split and (l >= 512 or l2 >= 512): + pad = 8 + w = len(h7[0][0][0]) + s = int(w / snum) + t = torch.split(h7, s, 3) + t1 = [] + for i in t: + t2 = self.conv2(i) + t1.append(t2) + del t2 + del t + temp1 = torch.cat(t1, 3) + del t1 + for i in range(1, snum): + ssize = [int(i * s) - 2, 4, w - int(i * s) - 2] + ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4] + t5 = torch.split(temp1, ssize, 3) + t6 = torch.split(h7, ssize2, 3) + t7 = t6[1] + del t6 + t8 = self.conv2(t7) + del t7 + t9 = torch.split(t8, [2, 4, 2], 3) + del t8 + temp1 = torch.cat((t5[0], t9[1], t5[2]), 3) + del t5 + del t9 + h8 = temp1 + del temp1 + else: + h8 = self.conv2(h7) del h7 if self.in_channels != self.out_channels: @@ -609,19 +717,29 @@ def forward(self, z): gc.collect() torch.cuda.empty_cache() + # snum = how many times to slice the upscaling + # 32 should work as a general number for 512x512 and higher + # (64 % snum) must be 0 + split = 1 + snum = 32 + # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) + if split: + h = self.up[i_level].block[i_block](h, temb, split=1, snum=snum) + else: + h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: t = 
h h = self.up[i_level].attn[i_block](t) del t if i_level != 0: - t = h - h = self.up[i_level].upsample(t) - del t + if i_level == 1 and split: + h = self.up[i_level].upsample(h, split=1, snum=snum) + else: + h = self.up[i_level].upsample(h) # end if self.give_pre_end: @@ -633,7 +751,37 @@ def forward(self, z): h2 = nonlinearity(h1) del h1 - h = self.conv_out(h2) + if split: + pad = 8 + w = len(h2[0][0][0]) + s = int(w / snum) + t = torch.split(h2, s, 3) + t1 = [] + for i in t: + t2 = self.conv_out(i) + t1.append(t2) + del t2 + temp1 = torch.cat(t1, 3) + del t + del t1 + for i in range(1, snum): + ssize = [int(i * s) - 2, 4, w - int(i * s) - 2] + ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4] + t5 = torch.split(temp1, ssize, 3) + t6 = torch.split(h2, ssize2, 3) + t7 = t6[1] + del t6 + t8 = self.conv_out(t7) + del t7 + t9 = torch.split(t8, [2, 4, 2], 3) + del t8 + temp1 = torch.cat((t5[0], t9[1], t5[2]), 3) + del t5 + del t9 + h = temp1 + del temp1 + else: + h = self.conv_out(h2) del h2 if self.tanh_out: @@ -907,5 +1055,4 @@ def forward(self,x): if self.do_reshape: z = rearrange(z,'b c h w -> b (h w) c') - return z - + return z \ No newline at end of file