-
Notifications
You must be signed in to change notification settings - Fork 72
/
ops_theta.py
103 lines (89 loc) · 4.67 KB
/
ops_theta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Function factory for pixel difference convolutional operations with vanilla conv components
please see line 49, the theta parameter was also used in "Yu et al, Searching central difference convolutional networks for face anti-spoofing, CVPR 2020"
Author: Zhuo Su
Date: Dec 29, 2021
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2d(nn.Module):
    """Convolution layer that owns the parameters and delegates the math to ``pdc``.

    The layer stores a weight of shape
    ``(out_channels, in_channels // groups, kernel_h, kernel_w)`` and an
    optional bias of shape ``(out_channels,)``. The actual computation is
    performed by the callable ``pdc``, which is invoked as
    ``pdc(input, weight, bias, stride, padding, dilation, groups)``.

    Args:
        pdc: Callable implementing the convolution (signature above).
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by ``groups``.
        kernel_size: Either a single int (square kernel) or a
            ``(kernel_h, kernel_w)`` pair.
        stride, padding, dilation, groups: Stored and forwarded unchanged
            to ``pdc`` on every forward call.
        bias: When true, a learnable bias parameter is created; otherwise
            ``self.bias`` is registered as ``None``.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.pdc = pdc

        # Accept both an int (square kernel) and an explicit (h, w) pair.
        if isinstance(kernel_size, int):
            kernel_h = kernel_w = kernel_size
        else:
            kernel_h, kernel_w = kernel_size

        # Allocated uninitialised; values are filled in by reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, kernel_h, kernel_w))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight (Kaiming uniform) and, if present, the bias."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # Bias is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply ``pdc`` to ``input`` using this layer's parameters and settings."""
        return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
## cd, ad, rd convolutions
## theta could be used to control the vanilla conv components
## theta = 0 reduces the function to vanilla conv, theta = 1 reduces the function to pure pdc (used in the paper)
def createConvFunc(op_type, theta):
    """Build a convolution function for the requested pixel-difference op.

    Args:
        op_type: One of ``'cv'`` (plain ``F.conv2d``), ``'cd'``, ``'ad'``
            or ``'rd'``.
        theta: Scale applied to the difference component. Ignored for
            ``'cv'``; for the other ops it must lie in ``(0, 1]``.

    Returns:
        ``F.conv2d`` for ``'cv'``; otherwise a function with the signature
        ``func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1)``
        that expects 3x3 kernels and ``dilation`` in ``{1, 2}``.

    Raises:
        AssertionError: On an unknown ``op_type`` or an out-of-range
            ``theta``. The returned functions also assert on kernel size,
            dilation and (for ``'cd'`` / ``'ad'``) ``padding == dilation``.
    """
    assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
    if op_type == 'cv':
        return F.conv2d

    assert theta > 0 and theta <= 1.0, 'theta should be within (0, 1]'

    if op_type == 'cd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
            assert padding == dilation, 'padding for cd_conv set wrong'
            # 1x1 kernel holding theta * (sum of each 3x3 kernel); its response
            # is subtracted from the regular convolution output.
            weights_c = weights.sum(dim=[2, 3], keepdim=True) * theta
            yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
            y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y - yc
        return func

    if op_type == 'ad':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
            assert padding == dilation, 'padding for ad_conv set wrong'
            shape = weights.shape
            # reshape (not view) so non-contiguous weights are handled too.
            flat = weights.reshape(shape[0], shape[1], -1)
            # Subtract theta times a permuted copy of the kernel (clock-wise).
            weights_conv = (flat - theta * flat[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).reshape(shape)
            y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func

    if op_type == 'rd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
            # The 3x3 kernel is expanded to 5x5, so the caller's padding is
            # overridden to keep the spatial size consistent.
            padding = 2 * dilation
            shape = weights.shape
            # Allocate with the same dtype and device as the weights so that
            # half/double precision and any device work.
            buffer = weights.new_zeros((shape[0], shape[1], 5 * 5))
            flat = weights.reshape(shape[0], shape[1], -1)
            # Positions of the flattened 5x5 kernel: these eight receive the
            # eight non-first weights ...
            buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = flat[:, :, 1:]
            # ... these eight receive the same weights negated and scaled by theta ...
            buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -flat[:, :, 1:] * theta
            # ... and the centre keeps the first weight scaled by (1 - theta).
            buffer[:, :, 12] = flat[:, :, 0] * (1 - theta)
            buffer = buffer.view(shape[0], shape[1], 5, 5)
            y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func

    # Only reachable when the op_type assertion above is disabled (python -O).
    print('impossible to be here unless you force that')
    return None