Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split the tensors in upscaling #16

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 161 additions & 14 deletions ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,50 @@ def __init__(self, in_channels, with_conv):
stride=1,
padding=1)

def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
def forward(self, x, split=0, snum=8):
    """Upsample ``x`` by 2x (nearest) and optionally apply ``self.conv``.

    With ``split`` falsy the whole tensor is processed at once.  With
    ``split`` truthy the interpolation and the convolution are run on
    slices taken along dimension 3 and concatenated again, and the columns
    around every slice boundary are recomputed from a small window of the
    un-sliced data so the result matches the unsplit path.

    NOTE(review): the seam repair assumes ``self.conv`` is a stride-1,
    padding-1 convolution with a 3-wide kernel, so that only the columns
    right next to a slice boundary are affected by slicing.  The kernel
    size is not visible in this block; confirm against ``__init__``.

    Args:
        x: input tensor with at least four dimensions; slicing happens
            along dimension 3.
        split: truthy to enable the sliced code path.
        snum: requested number of slices along dimension 3.  Values that
            do not divide the width evenly, or that exceed it, are handled
            by using more (or narrower) slices instead of failing.

    Returns:
        The upsampled tensor, convolved when ``self.with_conv`` is set.
    """
    if not split:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

    parts = max(1, int(snum))

    # Nearest-neighbour upsampling is column-local, so interpolating slice
    # by slice gives exactly the same values as interpolating the whole.
    # The slice width is clamped to 1 so a width smaller than ``snum``
    # does not produce a zero split size.
    step = max(1, x.shape[3] // parts)
    data = torch.cat(
        [
            torch.nn.functional.interpolate(piece, scale_factor=2.0, mode="nearest")
            for piece in torch.split(x, step, 3)
        ],
        3,
    )

    # Honour ``with_conv`` on the split path too, exactly like the unsplit
    # path does; previously the convolution was applied unconditionally.
    if not self.with_conv:
        return data

    width = data.shape[3]
    step = max(1, width // parts)
    x = torch.cat([self.conv(piece) for piece in torch.split(data, step, 3)], 3)

    # Each slice was convolved with zero padding at its own edges, so the
    # columns touching a slice boundary are wrong.  Walk over every real
    # boundary (not just ``snum - 1`` of them, which misses the extra
    # boundaries when the width is not a multiple of ``snum``), convolve an
    # up-to-8-column window centred on it and copy the middle columns back.
    for seam in range(step, width, step):
        # Window and target range are clipped to the tensor; where the
        # window is clipped its edge is the true image edge, so the zero
        # padding applied by the convolution there is the correct one.
        lo, hi = max(seam - 4, 0), min(seam + 4, width)
        first, last = max(seam - 2, 0), min(seam + 2, width)
        patch = self.conv(data[:, :, :, lo:hi])
        x[:, :, :, first:last] = patch[:, :, :, first - lo:last - lo]
        del patch
    del data
    return x


Expand Down Expand Up @@ -123,15 +163,50 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
stride=1,
padding=0)

def forward(self, x, temb):
def forward(self, x, temb, split = 0, snum=8):
h1 = x
h2 = self.norm1(h1)
del h1

h3 = nonlinearity(h2)
del h2

h4 = self.conv1(h3)
l = len(h3[0][0][0])
l2 = len(h3[0][0])

# Only split when the size along dim 2 or dim 3 is at least 512.
# At lower resolutions the slice width (width / snum) can be <= 8, and the 8-column seam windows (pad) would then overlap.
if split and (l >= 512 or l2 >= 512):
pad = 8
w = len(h3[0][0][0])
s = int(w / snum)
t = torch.split(h3, s, 3)
t1 = []
for i in t:
t2 = self.conv1(i)
t1.append(t2)
del t2
del t
temp1 = torch.cat(t1, 3)
del t1
for i in range(1, snum):
ssize = [int(i * s) - 2, 4, w - int(i * s) - 2]
ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4]
t5 = torch.split(temp1, ssize, 3)
t6 = torch.split(h3, ssize2, 3)
t7 = t6[1]
del t6
t8 = self.conv1(t7)
del t7
t9 = torch.split(t8, [2, 4, 2], 3)
del t8
temp1 = torch.cat((t5[0], t9[1], t5[2]), 3)
del t5
del t9
h4 = temp1
del temp1
else:
h4 = self.conv1(h3)
del h3

if temb is not None:
Expand All @@ -146,7 +221,40 @@ def forward(self, x, temb):
h7 = self.dropout(h6)
del h6

h8 = self.conv2(h7)
l = len(h7[0][0][0])
l2 = len(h7[0][0])
# same as above
if split and (l >= 512 or l2 >= 512):
pad = 8
w = len(h7[0][0][0])
s = int(w / snum)
t = torch.split(h7, s, 3)
t1 = []
for i in t:
t2 = self.conv2(i)
t1.append(t2)
del t2
del t
temp1 = torch.cat(t1, 3)
del t1
for i in range(1, snum):
ssize = [int(i * s) - 2, 4, w - int(i * s) - 2]
ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4]
t5 = torch.split(temp1, ssize, 3)
t6 = torch.split(h7, ssize2, 3)
t7 = t6[1]
del t6
t8 = self.conv2(t7)
del t7
t9 = torch.split(t8, [2, 4, 2], 3)
del t8
temp1 = torch.cat((t5[0], t9[1], t5[2]), 3)
del t5
del t9
h8 = temp1
del temp1
else:
h8 = self.conv2(h7)
del h7

if self.in_channels != self.out_channels:
Expand Down Expand Up @@ -609,19 +717,29 @@ def forward(self, z):
gc.collect()
torch.cuda.empty_cache()

# snum = number of slices the tensors are split into during upsampling
# 32 should work as a general value for 512x512 and higher resolutions
# (64 % snum) must be 0
split = 1
snum = 32

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if split:
h = self.up[i_level].block[i_block](h, temb, split=1, snum=snum)
else:
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t

if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
if i_level == 1 and split:
h = self.up[i_level].upsample(h, split=1, snum=snum)
else:
h = self.up[i_level].upsample(h)

# end
if self.give_pre_end:
Expand All @@ -633,7 +751,37 @@ def forward(self, z):
h2 = nonlinearity(h1)
del h1

h = self.conv_out(h2)
if split:
pad = 8
w = len(h2[0][0][0])
s = int(w / snum)
t = torch.split(h2, s, 3)
t1 = []
for i in t:
t2 = self.conv_out(i)
t1.append(t2)
del t2
temp1 = torch.cat(t1, 3)
del t
del t1
for i in range(1, snum):
ssize = [int(i * s) - 2, 4, w - int(i * s) - 2]
ssize2 = [int(i * s) - 4, 8, w - int(i * s) - 4]
t5 = torch.split(temp1, ssize, 3)
t6 = torch.split(h2, ssize2, 3)
t7 = t6[1]
del t6
t8 = self.conv_out(t7)
del t7
t9 = torch.split(t8, [2, 4, 2], 3)
del t8
temp1 = torch.cat((t5[0], t9[1], t5[2]), 3)
del t5
del t9
h = temp1
del temp1
else:
h = self.conv_out(h2)
del h2

if self.tanh_out:
Expand Down Expand Up @@ -907,5 +1055,4 @@ def forward(self,x):

if self.do_reshape:
z = rearrange(z,'b c h w -> b (h w) c')
return z

return z