Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobWang95 committed Feb 19, 2020
1 parent dc347f8 commit 2d8f9ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
30 changes: 19 additions & 11 deletions DCF.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.stride = stride
self.padding = padding
self.kernel_list = {}
self.num_bases = num_bases
assert mode in ['mode0', 'mode1'], 'Only mode0 and mode1 are available at this moment.'
self.mode = mode
self.bases_grad = bases_grad
Expand Down Expand Up @@ -111,10 +110,10 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
else:
self.register_parameter('bias', None)
self.reset_parameters()

self.num_bases = num_bases
if self.mode == 'mode1':
self.weight.data = self.weight.data.view(out_channels*in_channels, num_bases)
self.bases.data = self.bases.data.view(num_bases, kernel_size*kernel_size)
self.weight.data = self.weight.data.view(out_channels, in_channels, num_bases)
self.bases.data = self.bases.data.view(num_bases, kernel_size, kernel_size)
self.forward = self.forward_mode1
else:
self.forward = self.forward_mode0
Expand All @@ -128,24 +127,25 @@ def reset_parameters(self):
self.bias.data.zero_()

def forward_mode0(self, input):
FE_SIZE = input.size()
N, C, H, W = input.size()
feature_list = []
input = input.view(FE_SIZE[0]*FE_SIZE[1], 1, FE_SIZE[2], FE_SIZE[3])
input = input.view(N*C, 1, H, W)

feature = F.conv2d(input, self.bases,
None, self.stride, self.padding, dilation=self.dilation)

H = int((H-self.kernel_size+2*self.padding)/self.stride+1)
W = int((W-self.kernel_size+2*self.padding)/self.stride+1)

feature = feature.view(
FE_SIZE[0], FE_SIZE[1]*self.num_bases,
int((FE_SIZE[2]-self.kernel_size+2*self.padding)/self.stride+1),
int((FE_SIZE[3]-self.kernel_size+2*self.padding)/self.stride+1))
N, C*self.num_bases, H, W)

feature_out = F.conv2d(feature, self.weight, self.bias, 1, 0)

return feature_out

def forward_mode1(self, input):
rec_kernel = torch.mm(self.weight, self.bases).view(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
rec_kernel = torch.einsum('abc,cdf->abdf', self.weight, self.bases)

feature = F.conv2d(input, rec_kernel,
self.bias, self.stride, self.padding, dilation=self.dilation)
Expand All @@ -154,4 +154,12 @@ def forward_mode1(self, input):

def extra_repr(self):
return 'kernel_size={kernel_size}, stride={stride}, padding={padding}, num_bases={num_bases}' \
', bases_grad={bases_grad}, mode={mode}'.format(**self.__dict__)
', bases_grad={bases_grad}, mode={mode}'.format(**self.__dict__)


if __name__ == '__main__':
conv = Conv_DCF(10, 20, 3)
conv2 = Conv_DCF(10, 20, 3, mode='mode0')
data = torch.randn(2, 10, 16, 16)
print(conv(data).shape)
print(conv2(data).shape)
4 changes: 2 additions & 2 deletions ImageClassification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
name_file = sys.argv[0]
if os.path.exists(LOG_DIR): shutil.rmtree(LOG_DIR)
os.mkdir(LOG_DIR)
os.mkdir(LOG_DIR + '/train_img')
os.mkdir(LOG_DIR + '/test_img')
# os.mkdir(LOG_DIR + '/train_img')
# os.mkdir(LOG_DIR + '/test_img')
os.mkdir(LOG_DIR + '/files')
os.system('cp %s %s' % (name_file, LOG_DIR))
os.system('cp %s %s' % ('*.py', os.path.join(LOG_DIR, 'files')))
Expand Down

0 comments on commit 2d8f9ac

Please sign in to comment.