-
Notifications
You must be signed in to change notification settings - Fork 2
/
corr1d.py
33 lines (26 loc) · 1.51 KB
/
corr1d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tensorflow as tf
# Load the compiled custom-op library providing the `correlation1d` and
# `correlation1d_grad` ops used below. This runs at import time, so importing
# this module fails if correlation1d.so has not been built.
# NOTE(review): get_path_to_datafile presumably resolves './correlation1d.so'
# relative to this module's directory rather than the CWD — confirm for the
# TF version in use. `tf.resource_loader` is a TF1.x namespace; on TF2 it lives
# at `tf.compat.v1.resource_loader`.
_correlation1d_ops = tf.load_op_library(tf.resource_loader.get_path_to_datafile('./correlation1d.so'))
def correlation1d(input_a, input_b, kernel_size, max_displacement, stride1, stride2, padding):
    """Build a Correlation1d op that correlates ``input_a`` with ``input_b``.

    Thin wrapper around the ``correlation1d`` op exported by the custom-op
    library loaded at import time. Every argument is forwarded positionally,
    in exactly this order, to that op.

    Args:
        input_a: First input tensor. NOTE(review): the required layout and
            dtype are defined by the C++ kernel, which is not visible here;
            confirm before relying on a particular shape.
        input_b: Second input tensor, correlated against ``input_a``.
        kernel_size: Presumably becomes the op's ``kernel_size`` attr
            (the registered gradient reads an attr by that name).
        max_displacement: Presumably the op's ``max_displacement`` attr.
        stride1: Presumably the op's ``stride_1`` attr.
        stride2: Presumably the op's ``stride_2`` attr.
        padding: Presumably the op's ``pad`` attr.

    Returns:
        Whatever the generated ``correlation1d`` op wrapper returns
        (presumably the correlation output tensor).
    """
    # The generated op wrapper takes inputs first, then attrs, so the
    # positional order here must not change.
    op_args = (
        input_a,
        input_b,
        kernel_size,
        max_displacement,
        stride1,
        stride2,
        padding,
    )
    return _correlation1d_ops.correlation1d(*op_args)
@tf.RegisterGradient('Correlation1d')
def _correlation_grad(corr1d_op, gradients):
    """Gradient function registered for the ``Correlation1d`` op type.

    Reads the forward op's hyper-parameter attrs. It then passes the upstream
    gradient, the op's two original inputs, and those attrs to the
    ``correlation1d_grad`` kernel. It returns one gradient per forward input.

    Args:
        corr1d_op: The forward ``Correlation1d`` operation being differentiated.
        gradients: Upstream gradient with respect to the op's output.

    Returns:
        A 2-tuple ``(grad for inputs[0], grad for inputs[1])``, taken from the
        ``backpros_a`` / ``backpros_b`` outputs of ``correlation1d_grad``.
    """
    # Order matters: these are forwarded positionally to correlation1d_grad
    # after the three tensor arguments.
    attr_names = ('kernel_size', 'max_displacement', 'stride_1', 'stride_2', 'pad')
    attr_values = [corr1d_op.get_attr(attr) for attr in attr_names]

    first_input, second_input = corr1d_op.inputs[0], corr1d_op.inputs[1]
    grad_outputs = _correlation1d_ops.correlation1d_grad(
        gradients, first_input, second_input, *attr_values)

    # NOTE(review): "backpros_*" (rather than "backprops_*") must match the
    # output names registered for the C++ grad op; confirm against REGISTER_OP.
    return grad_outputs.backpros_a, grad_outputs.backpros_b