diff --git a/torchmultimodal/modules/encoders/clip_resnet_encoder.py b/torchmultimodal/modules/encoders/clip_resnet_encoder.py index 603aca92..4ce4f7de 100644 --- a/torchmultimodal/modules/encoders/clip_resnet_encoder.py +++ b/torchmultimodal/modules/encoders/clip_resnet_encoder.py @@ -24,16 +24,18 @@ def __init__(self, inplanes: int, planes: int, stride: int = 1): # an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * EXPANSION, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * EXPANSION) + self.relu3 = nn.ReLU(inplace=True) - self.relu = nn.ReLU(inplace=True) self.downsample = None self.stride = stride @@ -62,8 +64,8 @@ def __init__(self, inplanes: int, planes: int, stride: int = 1): def forward(self, x: torch.Tensor) -> torch.Tensor: identity = x - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) @@ -71,7 +73,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: identity = self.downsample(x) out += identity - out = self.relu(out) + out = self.relu3(out) return out @@ -161,14 +163,16 @@ def __init__( 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( width // 2, width // 2, kernel_size=3, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) # residual layers self._inplanes = width # this is a *mutable* variable used during construction @@ -215,12 +219,9 @@ def initialize_parameters(self): def forward(self, x: torch.Tensor) -> torch.Tensor: def stem(x): - for conv, bn in [ - (self.conv1, self.bn1), - (self.conv2, self.bn2), - (self.conv3, self.bn3), - ]: - x = self.relu(bn(conv(x))) + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x