-
Notifications
You must be signed in to change notification settings - Fork 0
/
kernel.go
105 lines (85 loc) · 2.23 KB
/
kernel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package robonet
import (
"errors"
"fmt"
"log"
)
// Kernel represents a basic convolution kernel. It embeds a Volume, so the
// Volume's methods are promoted to Kernel and can be called on it directly.
type Kernel struct {
	Volume
}
// errKernelDims is reported when a kernel is requested with dimensions that
// Odd3Dim rejects.
var errKernelDims = errors.New("kernel must have odd width and height")

// requireOddKernelDims terminates the process via log.Fatal unless Odd3Dim
// accepts the requested kernel dimensions. It is shared by every kernel
// constructor so the validation and its message live in one place.
func requireOddKernelDims(r, c, d int) {
	if !Odd3Dim(r, c, d) {
		log.Fatal(errKernelDims)
	}
}

// NewKernel creates a new r x c x d kernel initialized with zeros.
// It terminates the process if Odd3Dim rejects the dimensions.
func NewKernel(r, c, d int) Kernel {
	requireOddKernelDims(r, c, d)
	return Kernel{New(r, c, d)}
}

// Equals reports whether two kernels are equal, as decided by the Equals
// method of the embedded Volume.
//
// A value receiver is used, consistent with the other Kernel methods; it is
// still callable through a *Kernel.
func (kern Kernel) Equals(in Kernel) bool {
	return kern.Volume.Equals(in.Volume)
}

// NewKernelRandom creates a new r x c x d kernel initialized with random values
// (via NewRand). It terminates the process if Odd3Dim rejects the dimensions.
func NewKernelRandom(r, c, d int) Kernel {
	requireOddKernelDims(r, c, d)
	return Kernel{NewRand(r, c, d)}
}

// NewKernelFilled creates a new r x c x d kernel whose elements are filled with
// fil (via NewFull). It terminates the process if Odd3Dim rejects the
// dimensions.
func NewKernelFilled(r, c, d int, fil float64) Kernel {
	requireOddKernelDims(r, c, d)
	return Kernel{NewFull(r, c, d, fil)}
}
// Apply applies the kernel to an equally sized chunk of a volume and returns
// the scalar response. The response is the sum of the element-wise product of
// the point-reflected kernel and in.
//
// Only volumes with the same size as the kernel can be used:
//   - a nil volume panics (so the caller gets a stack trace);
//   - a size mismatch terminates the process via log.Fatalf.
//
// The reflection is done on a private copy, so Apply does not modify the
// receiver's data and repeated calls with the same input agree.
func (kern Kernel) Apply(in Volume) float64 {
	r, c, d := kern.Shape()
	// Check for nil before handing in to EqualSize, which would otherwise
	// receive a nil volume.
	if in == nil {
		panic(fmt.Sprintf("kernel: %vx%vx%v vol: nil", r, c, d))
	}
	if !kern.Volume.EqualSize(in) {
		log.Fatalf("kernel size doesn't match input: kernel %vx%vx%v, vol %vx%vx%v",
			r, c, d, in.Rows(), in.Collumns(), in.Depth())
	}

	// Volume is nil-comparable (see above), so copying kern would only copy a
	// reference and share the underlying data. Build an independent copy
	// before reflecting it, so the receiver is never modified.
	// NOTE(review): assumes PointReflect and MulElem work in place on prod —
	// confirm against Volume.
	prod := New(r, c, d)
	prod.SetAll(kern.Volume)
	// 1) reflect kernel
	prod.PointReflect()
	// 2) multiply pairwise
	prod.MulElem(in)

	// 3) accumulate the products
	sum := 0.0
	for i := 0; i < r; i++ {
		for j := 0; j < c; j++ {
			for k := 0; k < d; k++ {
				// TODO check if normalization is needed!
				sum += prod.GetAt(i, j, k)
			}
		}
	}
	return sum
}
// Elems reports how many elements the kernel holds, as counted by its
// embedded Volume.
func (kern Kernel) Elems() int {
	embedded := kern.Volume
	return embedded.Elems()
}
// Sum adds up every element of the kernel and returns the total. It walks
// rows in the outer loop, then columns, then depth.
func (kern Kernel) Sum() float64 {
	rows, cols, depth := kern.Rows(), kern.Collumns(), kern.Depth()
	var total float64
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			for z := 0; z < depth; z++ {
				total += kern.GetAt(row, col, z)
			}
		}
	}
	return total
}