-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path softmax_test.go
132 lines (115 loc) · 2.9 KB
/
softmax_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package bandit
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestSoftmax_New checks the argument validation performed by NewSoftmax:
// a negative temperature is rejected with ErrInvalidTemperature, and counts
// and rewards whose lengths differ (a nil slice has length 0) are rejected
// with ErrInvalidLength. All other cases must construct without error.
func TestSoftmax_New(t *testing.T) {
	assert := assert.New(t)

	tests := []struct {
		temperature float64
		counts      []int
		rewards     []float64
		err         error
	}{
		{0.1, nil, nil, nil},
		{0.1, []int{0, 0, 0}, []float64{0.0, 0.0, 0.0}, nil},
		{-0.1, nil, nil, ErrInvalidTemperature},
		{1.1, nil, nil, nil},
		{1.1, []int{0, 0}, nil, ErrInvalidLength},
		{1.0, []int{0, 0}, nil, ErrInvalidLength},
		{1.0, nil, []float64{0.0}, ErrInvalidLength},
		{1.0, []int{0, 0, 0, 0, 0}, []float64{0.0}, ErrInvalidLength},
	}

	for i, tt := range tests {
		_, err := NewSoftmax(tt.temperature, tt.counts, tt.rewards)
		if tt.err != nil {
			assert.Equal(tt.err, err, "should throw the correct error for test %d", i+1)
		} else {
			assert.NoError(err, "should not throw an error for test %d", i+1)
		}
	}
}
// TestSoftmax_Init checks that Init rejects arm counts of 0 or less with
// ErrInvalidArms and, for valid counts, sizes both Counts and Rewards to the
// requested number of arms. Each case starts from a fresh bandit.
func TestSoftmax_Init(t *testing.T) {
	assert := assert.New(t)

	tests := []struct {
		arms int
		err  error
	}{
		{-1, ErrInvalidArms},
		{0, ErrInvalidArms},
		{1, nil},
		{3, nil},
		{5, nil},
	}

	for i, tt := range tests {
		softmax, err := NewSoftmax(0.1, nil, nil)
		// The constructor's result must not be used if it reported an error.
		if !assert.NoError(err, "constructor should succeed for test %d", i+1) {
			return
		}

		err = softmax.Init(tt.arms)
		if tt.err != nil {
			// testify argument order is (expected, actual).
			assert.Equal(tt.err, err, "should throw error for invalid arms length")
			continue
		}

		assert.NoError(err)
		assert.Len(softmax.Counts, tt.arms, "counts should be of equal length with arm")
		assert.Len(softmax.Rewards, tt.arms, "rewards should be of equal length with arm")
	}
}
// TestSoftmax_SelectArm checks that a freshly initialised three-arm bandit
// returns arm 0 from SelectArm(0.1), and that Update accepts a reward for
// the arm that was selected.
func TestSoftmax_SelectArm(t *testing.T) {
	assert := assert.New(t)

	b, err := NewSoftmax(0.1, nil, nil)
	if !assert.NoError(err) {
		return
	}
	// SelectArm is only meaningful once the arms exist.
	if !assert.NoError(b.Init(3)) {
		return
	}

	arm := b.SelectArm(0.1)
	// testify argument order is (expected, actual).
	assert.Equal(0, arm, "should select the unplayed arm")
	assert.NoError(b.Update(arm, 1.0), "should accept a reward for the selected arm")
}
// TestSoftmax_UpdateArm feeds the rewards 1, 0, 1, 1 to arm 0 of a
// three-arm bandit. After each update it checks that Counts holds the
// number of updates so far and Rewards holds the mean of the rewards so far.
func TestSoftmax_UpdateArm(t *testing.T) {
	assert := assert.New(t)

	b, err := NewSoftmax(0.1, nil, nil)
	if !assert.NoError(err) {
		return
	}
	if !assert.NoError(b.Init(3)) {
		return
	}

	tests := []struct {
		arm            int
		reward         float64
		expectedCounts int
		expectedReward float64
	}{
		{0, 1.0, 1, 1.0},
		{0, 0.0, 2, 0.5},
		{0, 1.0, 3, 2.0 / 3.0},
		{0, 1.0, 4, 0.75},
	}

	for i, tt := range tests {
		assert.NoError(b.Update(tt.arm, tt.reward), "update should succeed for test %d", i+1)
		assert.Equal(tt.expectedCounts, b.Counts[tt.arm], "counts should be equal for test %d", i+1)
		// Means such as 2/3 are not exactly representable, and the order of
		// floating-point operations inside Update is an implementation detail,
		// so compare with a small tolerance instead of exact equality.
		assert.InDelta(tt.expectedReward, b.Rewards[tt.arm], 1e-9, "rewards should be equal for test %d", i+1)
	}
}
// TestSoftmax_UpdateArmWithInvalidParams re-initialises a single bandit with
// the given number of arms for each case, then checks the error returned by
// Update:
//   - ErrArmsIndexOutOfRange for an arm index outside the arms;
//   - ErrInvalidReward for a negative reward;
//   - no error otherwise.
//
// When both the arm index and the reward are invalid, the table expects the
// index error.
func TestSoftmax_UpdateArmWithInvalidParams(t *testing.T) {
	assert := assert.New(t)

	b, err := NewSoftmax(0.1, nil, nil)
	if !assert.NoError(err) {
		return
	}

	tests := []struct {
		arms      int
		chosenArm int
		reward    float64
		err       error
	}{
		{1, 3, 0.0, ErrArmsIndexOutOfRange},
		{3, 3, 0.0, ErrArmsIndexOutOfRange},
		{3, 2, 1.0, nil},
		{3, 2, 0.0, nil},
		{3, 2, -1.0, ErrInvalidReward},
		{3, 3, -1.0, ErrArmsIndexOutOfRange},
	}

	for i, tt := range tests {
		// Skip the Update check if Init failed; its result would be noise.
		if !assert.NoError(b.Init(tt.arms), "init should succeed for test %d", i+1) {
			continue
		}

		err = b.Update(tt.chosenArm, tt.reward)
		if tt.err != nil {
			assert.Equal(tt.err, err, "should throw the correct error for test %d", i+1)
		} else {
			assert.NoError(err, "should not throw an error for test %d", i+1)
		}
	}
}