-
Notifications
You must be signed in to change notification settings - Fork 0
/
pool_layer_test.go
88 lines (77 loc) · 2.09 KB
/
pool_layer_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package robonet
import (
"reflect"
"testing"
)
// Test_maxPool checks maxPool against expected result slices for an all-zero
// volume and for the package-level testVol fixture (declared elsewhere in the
// package). Results are compared with reflect.DeepEqual, so a nil result does
// not match an empty expected slice.
func Test_maxPool(t *testing.T) {
	cases := []struct {
		name  string
		input Volume
		want  []float64
	}{
		{name: "All Zeros", input: New(10, 10, 5), want: []float64{0, 0, 0, 0, 0}},
		{name: "testVol", input: testVol, want: []float64{8, 17, 26}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := maxPool(tc.input)
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("maxPool() = %v, want %v", got, tc.want)
		})
	}
}
// TestPoolLayer_Calculate checks PoolLayer.Calculate for several window-size /
// stride combinations by feeding an input volume, running Calculate, and
// comparing Output() with an expected volume via Volume.Equals.
//
// The "All Zeros" cases compare a zero input against a freshly created volume
// of the expected dimensions. The "testVol" cases use the package-level testVol
// fixture (declared elsewhere in the package) and compare against the
// hand-built volumes res1 and res2.
//
// Every case has a unique name that states its actual stride and window size,
// so `go test -run` can select it and failure output identifies it.
func TestPoolLayer_Calculate(t *testing.T) {
	// Expected output for testVol pooled with a 2x2 window and stride 1.
	res1 := New(2, 2, 3)
	res1.SetAt(0, 0, 0, 4)
	res1.SetAt(0, 1, 0, 5)
	res1.SetAt(1, 0, 0, 7)
	res1.SetAt(1, 1, 0, 8)
	res1.SetAt(0, 0, 1, 13)
	res1.SetAt(0, 1, 1, 14)
	res1.SetAt(1, 0, 1, 16)
	res1.SetAt(1, 1, 1, 17)
	res1.SetAt(0, 0, 2, 22)
	res1.SetAt(0, 1, 2, 23)
	res1.SetAt(1, 0, 2, 25)
	res1.SetAt(1, 1, 2, 26)

	// Expected output for testVol pooled with a 3x3 window and stride 1.
	res2 := New(1, 1, 3)
	res2.SetAt(0, 0, 0, 8)
	res2.SetAt(0, 0, 1, 17)
	res2.SetAt(0, 0, 2, 26)

	// fields mirrors the configurable PoolLayer parameters:
	// window size (SizeR, SizeC) and stride (StrideR, StrideC).
	type fields struct {
		SizeR   int
		SizeC   int
		StrideR int
		StrideC int
	}
	tests := []struct {
		name   string
		vol    Volume
		fields fields
		want   Volume
	}{
		{"All Zeros stride 2 size 2", New(6, 6, 3), fields{2, 2, 2, 2}, New(3, 3, 3)},
		{"All Zeros stride 3 size 3", New(6, 6, 3), fields{3, 3, 3, 3}, New(2, 2, 3)},
		{"All Zeros stride 6 size 6", New(6, 6, 3), fields{6, 6, 6, 6}, New(1, 1, 3)},
		{"All Zeros stride 2 size 4", New(6, 6, 3), fields{4, 4, 2, 2}, New(2, 2, 3)},
		{"All Zeros stride 6 size 4", New(10, 10, 3), fields{4, 4, 6, 6}, New(2, 2, 3)},
		{"All Zeros stride 5 size 2", New(10, 10, 3), fields{2, 2, 5, 5}, New(2, 2, 3)},
		// Previously both testVol cases were named "testVol stride 5 size 2",
		// which was wrong (stride is 1) and duplicated, so the subtest runner
		// renamed the second one to "...#01".
		{"testVol stride 1 size 2", testVol, fields{2, 2, 1, 1}, res1},
		{"testVol stride 1 size 3", testVol, fields{3, 3, 1, 1}, res2},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			lay := &PoolLayer{
				SizeR:   tt.fields.SizeR,
				SizeC:   tt.fields.SizeC,
				StrideR: tt.fields.StrideR,
				StrideC: tt.fields.StrideC,
			}
			lay.Input(tt.vol)
			lay.Calculate()
			if got := lay.Output(); !got.Equals(tt.want) {
				t.Errorf("PoolLayer.Calculate() = %v, want %v", got, tt.want)
			}
		})
	}
}