Skip to content

Commit

Permalink
Renamed x -> hidden_states in resnet.py (open-mmlab#676)
Browse files Browse the repository at this point in the history
renamed x to hidden_states
  • Loading branch information
daspartho committed Sep 29, 2022
1 parent 3dacbb9 commit a7058f4
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
else:
self.Conv2d_0 = conv

def forward(self, hidden_states):
    """Upsample `hidden_states` (N, C, H, W) by a factor of 2.

    Uses a transposed convolution when `use_conv_transpose` is set; otherwise
    nearest-neighbor interpolation, optionally followed by a convolution.
    """
    assert hidden_states.shape[1] == self.channels

    # The transposed convolution performs the upsampling itself.
    if self.use_conv_transpose:
        return self.conv(hidden_states)

    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

    # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
    if self.use_conv:
        conv_layer = self.conv if self.name == "conv" else self.Conv2d_0
        hidden_states = conv_layer(hidden_states)

    return hidden_states


class Downsample2D(nn.Module):
Expand Down Expand Up @@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
else:
self.conv = conv

def forward(self, hidden_states):
    """Downsample `hidden_states` (N, C, H, W) by a factor of 2 via `self.conv`."""
    assert hidden_states.shape[1] == self.channels

    # A stride-2 conv built with padding=0 needs asymmetric (right/bottom)
    # zero padding to produce the expected output size.
    if self.use_conv and self.padding == 0:
        hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

    assert hidden_states.shape[1] == self.channels
    return self.conv(hidden_states)


class FirUpsample2D(nn.Module):
Expand Down Expand Up @@ -174,12 +174,12 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

return x

def forward(self, hidden_states):
    """Upsample `hidden_states` 2x with the FIR kernel.

    When `use_conv` is set, the upsampling is fused with `Conv2d_0`'s weight
    and the conv bias is added afterwards.
    """
    # Renamed the local from the misleading `height` (it holds the full output
    # tensor, not a height) to `hidden_states`, matching FirDownsample2D.forward.
    if self.use_conv:
        hidden_states = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
        # Bias is broadcast over batch and spatial dims.
        hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
    else:
        hidden_states = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

    return hidden_states

Expand Down Expand Up @@ -236,14 +236,14 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

return x

def forward(self, hidden_states):
    """Downsample `hidden_states` 2x with the FIR kernel.

    When `use_conv` is set, the downsampling is fused with `Conv2d_0`'s weight
    and the conv bias is added afterwards.
    """
    # Non-conv path: plain FIR downsample, no bias.
    if not self.use_conv:
        return self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

    out = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
    # Bias is broadcast over batch and spatial dims.
    return out + self.Conv2d_0.bias.reshape(1, -1, 1, 1)


class ResnetBlock2D(nn.Module):
Expand Down

0 comments on commit a7058f4

Please sign in to comment.