From 07dad864356c401f265d908aa80914e93cfb691b Mon Sep 17 00:00:00 2001 From: HeMuling <74801533+HeMuling@users.noreply.github.com> Date: Thu, 9 Mar 2023 11:44:54 +0800 Subject: [PATCH] Create bam.py update to allow BAM to handle case when input height not equals input width --- MODELS/bam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MODELS/bam.py b/MODELS/bam.py index cbda3d0..b2204b9 100644 --- a/MODELS/bam.py +++ b/MODELS/bam.py @@ -21,7 +21,7 @@ def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() ) self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) ) def forward(self, in_tensor): - avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) ) + avg_pool = F.avg_pool2d( in_tensor, (in_tensor.size(2), in_tensor.size(3)), stride=in_tensor.size(2) ) return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor) class SpatialGate(nn.Module):