Skip to content

Commit

Permalink
Don't reuse nn.ReLU modules in CLIP ResNet (facebookresearch#33)
Browse files Browse the repository at this point in the history
Summary:
Reusing the same ReLU module for multiple layers can make it more difficult for researchers as for example PyTorch's hook system won't work properly on the reused layer modules. I ran into this issue while building and testing interpretability tools on the CLIP models.

This PR doesn't change how any of the models work. It merely makes it possible to access and research each ReLU layer separately. Let me know if I need to make any changes before it can be merged!

Pull Request resolved: facebookresearch#33

Reviewed By: ankitade

Differential Revision: D36110555

Pulled By: ebsmothers

fbshipit-source-id: 992ae5bb53dd1fe83e793f55cc7258cc06516a74
  • Loading branch information
ProGamerGov authored and facebook-github-bot committed May 3, 2022
1 parent 292219e commit d216331
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions torchmultimodal/modules/encoders/clip_resnet_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ def __init__(self, inplanes: int, planes: int, stride: int = 1):
# an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)

self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

self.conv3 = nn.Conv2d(planes, planes * EXPANSION, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * EXPANSION)
self.relu3 = nn.ReLU(inplace=True)

self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride

Expand Down Expand Up @@ -62,16 +64,16 @@ def __init__(self, inplanes: int, planes: int, stride: int = 1):
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x

out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)
out = self.relu3(out)
return out


Expand Down Expand Up @@ -161,14 +163,16 @@ def __init__(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)

# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
Expand Down Expand Up @@ -215,12 +219,9 @@ def initialize_parameters(self):

def forward(self, x: torch.Tensor) -> torch.Tensor:
def stem(x):
for conv, bn in [
(self.conv1, self.bn1),
(self.conv2, self.bn2),
(self.conv3, self.bn3),
]:
x = self.relu(bn(conv(x)))
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x

Expand Down

0 comments on commit d216331

Please sign in to comment.