-
Notifications
You must be signed in to change notification settings - Fork 69
/
bandits.jl
107 lines (95 loc) · 3.23 KB
/
bandits.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
using Printf
using Random
using PGFPlots
"""
    Bandit(θ)
    Bandit(k::Integer)

A multi-armed bandit. `θ[i]` is the true probability that pulling arm `i`
pays out. `Bandit(k)` creates `k` arms with probabilities drawn by `rand(k)`.
"""
mutable struct Bandit
    θ::Vector{Float64} # true bandit probabilities
end

function Bandit(k::Integer)
    probabilities = rand(k)
    return Bandit(probabilities)
end

"""
    pull(b::Bandit, i::Integer)

Pull arm `i` once. Returns `true` when a fresh `rand()` draw falls below
`b.θ[i]`, and `false` otherwise.
"""
function pull(b::Bandit, i::Integer)
    draw = rand()
    return draw < b.θ[i]
end

"""
    numArms(b::Bandit)

Number of arms, i.e. the length of `b.θ`.
"""
function numArms(b::Bandit)
    return length(b.θ)
end
"""
    _get_string_list_of_percentages(bandit_odds)

Format every probability in `bandit_odds` as a percentage with two decimals
and return them as a single comma-separated string. For example
`[0.5, 0.25]` becomes `"50.00 percent, 25.00 percent"`.

An empty input yields the empty string.
"""
function _get_string_list_of_percentages(bandit_odds::AbstractVector{R}) where {R<:Real}
    # `join` copes with zero or one element, so no special-casing of the
    # first entry is needed and no intermediate strings are concatenated.
    return join((Printf.@sprintf("%.2f percent", 100θ) for θ in bandit_odds), ", ")
end
"""
    banditTrial(b)

Display an interactive widget set for trying out the bandit `b`.

For every arm a button is shown together with a summary line of the form
"W wins out of T tries (P percent)". The button's value is used as the number
of tries, and each change of that value to something positive adds the result
of `pull(b, i)` to the arm's win counter. A final "True Params" toggle reveals
or hides the true probabilities in `b.θ`, formatted as percentages.

Nothing meaningful is returned; the function is called for its `display`
side effects.
"""
function banditTrial(b)
    for i in 1 : numArms(b)
        but = button("Arm $i", value=0)
        display(but)
        wins = Observable(0)
        # NOTE(review): this handler is registered before the summary below,
        # presumably so `wins` is updated before the text is recomputed.
        # Keep the order.
        Interact.@on &but>0 ? (wins[] = wins[]+pull(b,i)) : 0
        display(map(s -> begin
            tries = but[]
            # Before the first pull 100*0/0 would be NaN; report 0 percent instead.
            percent = tries == 0 ? 0.0 : 100*wins[]/tries
            Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], tries, percent)
        end, but))
        # NOTE: we used to use the latex() wrapper
    end
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params")
    display(t)
    display(map(v -> v == "Show" ? _get_string_list_of_percentages(b.θ) : "", t))
end
"""
    banditEstimation(b)

Display an interactive widget set for estimating the arm probabilities of
the bandit `b`.

For every arm this shows a button, a summary line of the form
"W wins out of T tries (P percent)", and a plot of the density of
`Beta(W+1, T-W+1)` over the interval (0, 1). The button's value is used as the
number of tries `T`, and each change of that value to something positive adds
the result of `pull(b, i)` to the win counter `W`. A final "True Params"
toggle reveals or hides `string(b.θ)`.

Nothing meaningful is returned; the function is called for its `display`
side effects.
"""
function banditEstimation(b)
    for i in 1 : numArms(b)
        but = button("Arm $i", value=0)
        display(but)
        wins = Observable(0)
        # NOTE(review): this handler is registered before the two displays
        # below, presumably so `wins` is updated before they are recomputed.
        # Keep the order.
        Interact.@on &but>0 ? (wins[] = wins[]+pull(b,i)) : 0
        display(map(s -> begin
            tries = but[]
            # Before the first pull 100*0/0 would be NaN; report 0 percent instead.
            percent = tries == 0 ? 0.0 : 100*wins[]/tries
            Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], tries, percent)
        end, but))
        display(map(s -> begin
            w = wins[]
            t = but[]
            # With w wins in t tries the plotted density is Beta(w+1, t-w+1).
            Axis([
                Plots.Linear(θ->pdf(Beta(w+1, t-w+1), θ), (0,1), legendentry="Beta($(w+1), $(t-w+1))")
                ],
                xmin=0,xmax=1,ymin=0, width="15cm", height="10cm")
        end, but
        ))
    end
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params")
    display(t)
    display(map(v -> v == "Show" ? string(b.θ) : "", t))
end
"""
    BanditStatistics(k::Integer)

Per-arm win and pull counters for a bandit with `k` arms. Both counters start
at zero.
"""
mutable struct BanditStatistics
    numWins::Vector{Int}  # successful pulls recorded per arm
    numTries::Vector{Int} # total pulls recorded per arm
    # Allocate integer zeros directly instead of converting Float64 zeros.
    BanditStatistics(k::Integer) = new(zeros(Int, k), zeros(Int, k))
end

"""
    numArms(b::BanditStatistics)

Number of arms being tracked, i.e. the length of `b.numWins`.
"""
numArms(b::BanditStatistics) = length(b.numWins)

"""
    update!(b::BanditStatistics, i::Integer, success::Bool)

Record one pull of arm `i`: its try count is always incremented, its win
count only when `success` is `true`.

Returns the new win count of arm `i` on success and `nothing` otherwise.
`simulate` ignores the return value.
"""
function update!(b::BanditStatistics, i::Integer, success::Bool)
    b.numTries[i] += 1
    if success
        b.numWins[i] += 1
    end
end
"""
    winProbabilities(b::BanditStatistics)

Estimated win probability of every arm assuming a uniform prior: for an arm
with `w` wins out of `t` tries the estimate is `(w + 1) / (t + 2)`.
Returns one value per arm.
"""
function winProbabilities(b::BanditStatistics)
    return @. (b.numWins + 1) / (b.numTries + 2)
end
"""
    BanditPolicy

Supertype of all arm-selection policies. `simulate` calls `reset!(policy)`
once before a run and `arm(policy, stats)` at every step.
"""
abstract type BanditPolicy end

"""
    reset!(p::BanditPolicy)

Default reset hook: does nothing and returns `nothing`. A policy that needs
to clear state between runs can add its own method.
"""
function reset!(p::BanditPolicy)
    return nothing
end
"""
    simulate(b::Bandit, policy::BanditPolicy; steps = 10)

Play `policy` against the bandit `b` for `steps` pulls, starting from empty
statistics and a freshly reset policy.

At each step the policy picks an arm via `arm(policy, stats)`, the arm is
pulled, and the outcome is recorded in the statistics. Returns a
`Vector{Int}` of length `steps` holding 1 for a win and 0 for a loss.
"""
function simulate(b::Bandit, policy::BanditPolicy; steps = 10)
    outcomes = zeros(Int, steps)
    stats = BanditStatistics(numArms(b))
    reset!(policy)
    for t in eachindex(outcomes)
        chosen = arm(policy, stats)
        won = pull(b, chosen)
        update!(stats, chosen, won)
        outcomes[t] = won
    end
    return outcomes
end
"""
    simulateAverage(b::Bandit, policy::BanditPolicy; steps = 10, iterations = 10)

Run `simulate` `iterations` times and average the outcomes step by step.
Returns a vector of length `steps` whose entry `t` is the fraction of runs
that won at step `t`.
"""
function simulateAverage(b::Bandit, policy::BanditPolicy; steps = 10, iterations = 10)
    totals = zeros(Int, steps)
    for _ in 1:iterations
        outcome = simulate(b, policy; steps=steps)
        totals .+= outcome
    end
    return totals ./ iterations
end
"""
    learningCurves(b::Bandit, policies; steps=10, iterations=10)

Build one plot line per policy showing its averaged outcomes on bandit `b`.

`policies` is iterated as `(name, policy)` pairs; `name` becomes the legend
entry. Returns a `Vector` of `Plots.Linear`, in iteration order.
"""
function learningCurves(b::Bandit, policies; steps=10, iterations=10)
    return Plots.Linear[
        Plots.Linear(
            simulateAverage(b, policy; steps=steps, iterations=iterations),
            legendentry=name, style="very thick", mark="none",
        )
        for (name, policy) in policies
    ]
end