diff --git a/MODELS/bam.py b/MODELS/bam.py index cbda3d0..b2204b9 100644 --- a/MODELS/bam.py +++ b/MODELS/bam.py @@ -21,7 +21,7 @@ def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() ) self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) ) def forward(self, in_tensor): - avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) ) + avg_pool = F.avg_pool2d( in_tensor, (in_tensor.size(2), in_tensor.size(3)), stride=in_tensor.size(2) ) return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor) class SpatialGate(nn.Module):