-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
348 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,92 @@ | ||
module PIQL | ||
export ControlProblem, average_reward | ||
|
||
|
||
|
"""
    ControlProblem(action_space, action_prior, propagator, reward_function,
                   propagator_average, terminal_condition, initial_state, γ)

Container whose fields are the functions that completely define a KL-control problem.

# Fields
- `action_space`: something that can be iterated over to enumerate the actions.
- `action_prior`: `π(s, a) -> Float64`, "exactly like energy"; assumed to be normalized.
- `propagator`: `p(x0, a) -> x1`, returns a ("random") successor state.
- `reward_function`: `r(x0, a, x1) -> Float64`, the reward, already in entropic units.
- `propagator_average`: `(s, a, f) -> K·f`, the average of `f` over successor states.
- `terminal_condition`: `T(x) -> Bool`, whether `x` is a terminal state.
- `initial_state`: `W() -> x0`, generates initial states of interest.
- `γ`: discount over time, nominally a positive number less than one.
"""
struct ControlProblem{AA,U,P,R,PA,T,W}
    action_space::AA
    action_prior::U
    propagator::P
    reward_function::R
    propagator_average::PA
    terminal_condition::T
    initial_state::W
    γ::Float64
end
# Write your package code here. | ||
|
||
|
"""
    average_reward(ctrl, s, a)

Return the reward averaged over successor states: `ctrl.propagator_average` is applied
to the map `s1 -> ctrl.reward_function(s, a, s1)` for the given state `s` and action `a`.
"""
function average_reward(ctrl, s, a)
    # Reward as a function of the successor state only, with `s` and `a` held fixed.
    reward_of_successor(s1) = ctrl.reward_function(s, a, s1)
    return ctrl.propagator_average(s, a, reward_of_successor)
end
|
||
|
"""
    controlV(ctrl, V, s, a)

Evaluate the `controlV` function built from a value function `V`:

    ctrl.γ * ctrl.propagator_average(s, a, V) - V(s) + average_reward(ctrl, s, a)

# Arguments
- `ctrl`: the control problem; its `γ`, `propagator_average` and (through
  `average_reward`) `reward_function` fields are used.
- `V`: callable `V(s)` giving the value of a state.
- `s`, `a`: the state and action at which to evaluate.
"""
function controlV(ctrl, V, s, a)
    # The current-state value comes from the `V` argument; the previous code referred
    # to a `value_function` that is not defined anywhere.
    discounted_next = ctrl.γ * ctrl.propagator_average(s, a, V)
    return discounted_next - V(s) + average_reward(ctrl, s, a)
end
|
"""
    Qnormalization(ctrl, Q, s)

Return the log-normalization of an unnormalized Q function at state `s`:

    log( sum over a in ctrl.action_space of ctrl.action_prior(s, a) * exp(Q(s, a)) )

# Arguments
- `ctrl`: the control problem; its `action_space` and `action_prior` fields are used.
- `Q`: callable `Q(s, a)`.
- `s`: the state at which to normalize.
"""
function Qnormalization(ctrl, Q, s)
    z = 0.0
    # The actions and the prior are fields of `ctrl` (`action_space`, `action_prior`).
    for a in ctrl.action_space
        z += ctrl.action_prior(s, a) * exp(Q(s, a))
    end
    return log(z)
end
|
"""
    controlQ(ctrl, Q, s, a)

Return `Q(s, a)` shifted by the state-dependent log-normalization, i.e.
`Q(s, a) - Qnormalization(ctrl, Q, s)`.
"""
function controlQ(ctrl, Q, s, a)
    raw_q = Q(s, a)
    log_norm = Qnormalization(ctrl, Q, s)
    return raw_q - log_norm
end
|
||
"""
    action_probability(ctrl, Q, s, a)

Return the weight of action `a` in state `s` implied by `Q`: the prior
`ctrl.action_prior(s, a)` multiplied by `exp(controlQ(ctrl, Q, s, a))`.
"""
function action_probability(ctrl, Q, s, a)
    # The prior lives in the `action_prior` field; `ControlProblem` has no `prior` field.
    return ctrl.action_prior(s, a) * exp(controlQ(ctrl, Q, s, a))
end
|
||
"""
    fitness(ctrl, V, Q, s, a)

Return the difference `controlV(ctrl, V, s, a) - controlQ(ctrl, Q, s, a)`.
"""
function fitness(ctrl, V, Q, s, a)
    v_part = controlV(ctrl, V, s, a)
    q_part = controlQ(ctrl, Q, s, a)
    return v_part - q_part
end
|
"""
    truevalue_recursion(ctrl, s, ν)

Backup operator for the Bellman equation (see eq. 7 in the paper), used to generate the
optimal control directly. Returns

    log( sum over a of ctrl.action_prior(s, a) *
         exp(ctrl.γ * ctrl.propagator_average(s, a, ν) + average_reward(ctrl, s, a)) )

# Arguments
- `ctrl`: the control problem.
- `s`: the state being backed up.
- `ν`: callable `ν(s)` giving the current value estimate of a state.
"""
function truevalue_recursion(ctrl, s, ν)
    z = 0.0
    # Actions and prior are read from `ctrl` (`action_space`, `action_prior`).
    for a in ctrl.action_space
        exponent = ctrl.γ * ctrl.propagator_average(s, a, ν) + average_reward(ctrl, s, a)
        z += ctrl.action_prior(s, a) * exp(exponent)
    end
    return log(z)
end
|
"""
    z_recursion(ctrl, Q, V, s, Z; β = 1.0)

One step of the recursion used to determine `Z`. Returns the sum over actions `a` of

    action_probability(ctrl, Q, s, a) * exp(β * fitness(ctrl, V, Q, s, a)) *
        ctrl.propagator_average(s, a, Z)

# Arguments
- `ctrl`: the control problem.
- `Q`: callable `Q(s, a)`.
- `V`: callable `V(s)`.
- `s`: the state at which to evaluate the recursion.
- `Z`: callable `Z(s)` giving the current estimate for each state.

# Keywords
- `β`: multiplier applied to the fitness inside the exponential.
"""
function z_recursion(ctrl, Q, V, s, Z; β = 1.0)
    out = 0.0
    for a in ctrl.action_space
        weight = action_probability(ctrl, Q, s, a) * exp(β * fitness(ctrl, V, Q, s, a))
        # `propagator_average` is a field of `ctrl` with signature (s, a, f).
        out += weight * ctrl.propagator_average(s, a, Z)
    end
    return out
end
|
||
include("value_iteration.jl") | ||
include("Problems/BinaryBandit.jl") | ||
include("Problems/GridWorld.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
export binary_bandit_problem | ||
|
||
using Distributions
using Random: shuffle
|
||
# the state structure is | ||
# state = (((h1,t1),...,(hn,tn)), rounds_left) | ||
|
||
"""
    BanditState(belief, rounds)

State of the bandit problem.

# Fields
- `belief`: vector indexed by action; the functions in this file read each entry as an
  `(h, t)` pair.
- `rounds`: integer counter that `update` decrements by one.
"""
struct BanditState{I<:AbstractVector, R<:Integer}
    belief::I
    rounds::R
end
|
||
"""
    observe(bs, a)

Draw a random observation for action `a`: a `Bernoulli` sample whose success
probability is `h / (h + t)`, with `(h, t) = bs.belief[a]`.
"""
function observe(bs, a)
    heads, tails = bs.belief[a]
    p_heads = heads / (heads + tails)
    return rand(Bernoulli(p_heads))
end
|
||
"""
    update(bs, a, o)

Return a new `BanditState` after observing outcome `o` for action `a`. The belief vector
is copied, entry `a` becomes `(h + o, t + !o)`, and `rounds` is decremented by one. The
input state `bs` is not modified.
"""
function update(bs, a, o)
    new_belief = copy(bs.belief)          # keep the caller's state untouched
    heads, tails = new_belief[a]
    new_belief[a] = (heads + o, tails + !o)
    rounds_left = bs.rounds - 1
    return BanditState(new_belief, rounds_left)
end
|
||
"""
    reachable(bs)

Return the vector of states obtained from `bs` by one `update`, for every action index
of `bs.belief` and for both outcomes (`true` first, then `false`).
"""
function reachable(bs)
    outcomes = (true, false)
    return [update(bs, arm, outcome) for arm in eachindex(bs.belief) for outcome in outcomes]
end
|
||
|
||
"""
    binary_bandit_problem(; rounds = 10, ncoins = 2, prior = (1,1), payoff = 1.0, γ = 1.0)

Build the `ControlProblem` for a bandit with `ncoins` actions.

# Keywords
- `rounds`: value of the `rounds` field of the initial state.
- `ncoins`: number of actions; the action space is `1:ncoins` with a uniform prior.
- `prior`: initial `(h, t)` belief used for every action.
- `payoff`: multiplier applied to the increase of the first belief component (`h`) of the
  chosen action between two states.
- `γ`: discount passed through to the `ControlProblem`.
"""
function binary_bandit_problem(; rounds = 10, ncoins = 2, prior = (1,1), payoff = 1.0, γ = 1.0)
    coins = 1:ncoins
    uniform_prior = (s, a) -> 1.0 / ncoins

    # Random transition: observe the chosen action, then fold the outcome into the belief.
    sample_next = (s0, a) -> update(s0, a, observe(s0, a))

    # Reward is the change in `h` of the chosen action, scaled by `payoff`.
    heads_reward = (s0, a, s1) -> (s1.belief[a][1] - s0.belief[a][1])*payoff

    # Average of `f` over both outcomes, weighted t/(h+t) for `false` and h/(h+t) for `true`.
    function expected_next(s, a, f)
        h, t = s.belief[a]
        return f(update(s, a, false))*(t/(h+t)) + f(update(s, a, true))*(h/(h+t))
    end

    out_of_rounds = bs -> bs.rounds < 1
    fresh_state = () -> BanditState([prior for _ in 1:ncoins], rounds)

    return ControlProblem(
        coins,
        uniform_prior,
        sample_next,
        heads_reward,
        expected_next,
        out_of_rounds, fresh_state, γ)
end
|
||
|
||
"""
    state_iterator(; rounds = 10, ncoins = 2, prior = (1,1))

Enumerate bandit states level by level, starting from the initial state
`BanditState([prior for _ in 1:ncoins], rounds)` and applying `reachable` to every state
of the previous level `rounds - 1` times. States are not de-duplicated.

Returns a reversed iterator over the collected states, so the most recently generated
level comes first and the initial state comes last.
"""
function state_iterator(; rounds = 10, ncoins = 2, prior = (1,1))
    out = [BanditState([prior for _ in 1:ncoins], rounds)]
    cur = out
    for _ in 1:rounds-1
        next = similar(cur, 0)
        for state in cur
            append!(next, reachable(state))
        end
        # `vcat` is non-mutating: its result must be assigned back, otherwise every
        # level after the initial state is silently dropped.
        out = vcat(out, next)
        cur = next
    end
    return Iterators.reverse(out)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
|
||
export make_gridworld, make_ctrl | ||
# struct ControlProblem{A, U, P, R, PA, T, W} | ||
# action_space::Vector{A} # something that we can iterate over | ||
# action_prior::U # π(s,a) -> Float64 exactly like energy | ||
# propagator::P # p(x0, a) -> x1 ("random" state) | ||
# reward_function::R # r(x0, a, x1) -> reward ::Float64 | ||
# # given in entropic units already | ||
# propagator_average::PA # (s,a,f) -> K·f | ||
#     terminal_condition::T # T(x) -> Bool | ||
#     initial_state::W # W() -> x0 generates initial states of interest | ||
# γ::Float64 # positive number less than one discount over time | ||
# end | ||
|
||
using StatsBase | ||
|
"""
    step_choices(; dim = Val(2))

Return the range of action indices for a grid of dimension `G` (given as `Val(G)`).
Actions are indexed by the integers `0:2G-1`, i.e. `0:3` in 2D.
"""
function step_choices(; dim::Val{G} = Val(2)) where G
    n_actions = 2*G          # two directions per axis
    return 0:(n_actions - 1)
end
|
||
"""
    step(a; dim = Val(2))

Translate action index `a` into a unit displacement `CartesianIndex` of dimension `G`
(given as `Val(G)`). The axis is `div(a, 2) + 1`; even `a` moves by `+1` along it and odd
`a` by `-1`. All other components are zero.
"""
function step(a; dim::Val{G} = Val(2)) where G
    direction = 1 - 2*mod(a, 2)      # +1 for even actions, -1 for odd ones
    axis = div(a, 2)                 # zero-based axis the action moves along
    return CartesianIndex(ntuple(d -> direction*(axis == d - 1), G))
end
|
||
|
||
"""
    take_step(x0, a; dim = Val(2), walls = Bool[])

Move from `x0` by the displacement of action `a`. If the target cell is marked `true` in
`walls` the move is rejected and `x0` is returned; otherwise the target cell is returned.
The target is asserted to have the same type as `x0`.
"""
function take_step(x0::G, a; dim = Val(2), walls = Bool[]) where G
    target::G = x0 + step(a; dim)
    blocked = walls[target]
    return blocked ? x0 : target
end
|
||
|
||
"""
    Gridworld(walls, goal)

An `N`-dimensional grid world.

# Fields
- `walls`: `Array{Bool,N}`; `true` marks a cell that cannot be entered.
- `goal`: `CartesianIndex{N}` of the goal cell.

The type is parametrized on the dimension `N` so that both fields are concretely typed.
Any `AbstractArray{Bool,N}` (for example the `BitArray` returned by `make_walls`) is
accepted by the constructor and converted to an `Array{Bool,N}`.
"""
struct Gridworld{N}
    walls::Array{Bool,N}
    goal::CartesianIndex{N}
end

# Keep the two-argument constructor working for non-`Array` inputs such as `BitArray`;
# the inner constructor performs the conversion to `Array{Bool,N}`.
Gridworld(walls::AbstractArray{Bool,N}, goal::CartesianIndex{N}) where {N} =
    Gridworld{N}(walls, goal)
|
||
"""
    make_walls(size; density = 0.1)

Create a wall mask for a grid whose interior has extents `size`.

The returned `BitArray` has every extent enlarged by 2. The first and last slice along
every dimension are set to `true` (a closed border), and each cell is additionally turned
into a wall with probability `density`.

# Arguments
- `size`: tuple of interior extents, one per dimension (any number of dimensions).

# Keywords
- `density`: probability with which each cell is made a wall.
"""
function make_walls(size; density = 0.1)
    walls = falses((size .+ 2)...)
    # Close the border along every dimension, not just the first two, so that the
    # dimension-generic step functions work for any grid rank.
    for d in 1:ndims(walls)
        selectdim(walls, d, firstindex(walls, d)) .= true
        selectdim(walls, d, lastindex(walls, d)) .= true
    end
    for ii in eachindex(walls)
        if rand() < density
            walls[ii] = true # put down random barriers
        end
    end
    return walls
end
|
||
"""
    draw_not_wall(walls)

Draw a uniformly random cell of `walls` that is not a wall, by rejection sampling, and
return it as a `CartesianIndex`.

# Throws
- `ArgumentError` if every cell of `walls` is a wall (the rejection loop would otherwise
  never terminate).
"""
function draw_not_wall(walls)
    all(walls) && throw(ArgumentError("walls contains no free cell to draw from"))
    cells = CartesianIndices(walls)
    while true
        ii = rand(LinearIndices(walls))
        if !walls[ii]
            return cells[ii]
        end
    end
end
|
||
"""
    get_reachable(walls, goal)

Return the `Set` of cells connected to `goal` through single steps that never enter a
wall. The search expands a frontier outwards from `goal`, one layer of neighbours at a
time, until no new cell is found. `goal` itself is always part of the result.
"""
function get_reachable(walls, goal)
    dim = Val(ndims(walls))
    moves = step.(step_choices(; dim); dim)
    reachable_set = Set([goal])
    frontier = Set([goal])
    while !isempty(frontier)
        fresh = empty(frontier)
        for cell in frontier, move in moves
            neighbour = cell + move
            if !(neighbour in reachable_set) && !walls[neighbour]
                push!(fresh, neighbour)
            end
        end
        union!(reachable_set, fresh)
        frontier = fresh
    end
    return reachable_set
end
|
||
"""
    make_gridworld(size; density = 0.1)

Generate a random `Gridworld`: build the walls with `make_walls(size; density)`, draw a
goal cell that is not a wall, and turn every cell that cannot be reached from the goal
into a wall.
"""
function make_gridworld(size; density = 0.1)
    walls = make_walls(size; density)
    goal = draw_not_wall(walls)
    connected = get_reachable(walls, goal)
    # Wall off everything that is disconnected from the goal.
    for cell in CartesianIndices(walls)
        cell in connected || (walls[cell] = true)
    end
    return Gridworld(walls, goal)
end
|
||
"""
    make_ctrl(gw::Gridworld; randomness, reward_scale, step_cost, reward, γ)

Build the `ControlProblem` for moving through the grid world `gw` towards `gw.goal`.

# Keywords
- `randomness = 0.1`: likelihood of choosing a random action instead of the requested one.
- `reward_scale = 0.2`: temperature; sets the defaults of `step_cost` and `reward`.
- `step_cost = 0.1*reward_scale`: entropic cost of a step that does not land on the goal.
- `reward = 1*reward_scale`: reward for a step that lands on the goal.
- `γ = 0.99`: discounting rate.
"""
function make_ctrl(gw::Gridworld;
        randomness = 0.1,
        reward_scale = 0.2,
        step_cost = 0.1*reward_scale,
        reward = 1*reward_scale,
        γ = 0.99)
    # The `let` block hands the closures fresh bindings to capture (an attempt to keep
    # the closures fast).
    let dim = Val(ndims(gw.walls)), actions = step_choices(; dim), len = length(actions), walls = gw.walls
        # Propagator average: the requested action gets weight (1 - randomness) on top
        # of the uniform share randomness / len that every action receives.
        function expected_next(s, a, f)
            acc::Float64 = 0.0
            for candidate in actions
                weight = (a == candidate)*(1-randomness) + (randomness / len)
                acc += weight * f(take_step(s, candidate; dim, walls))
            end
            return acc
        end

        uniform_prior(s, a) = 1.0 / len

        # Random transition. The random action is always drawn, even when it is not
        # used, exactly as `ifelse` evaluates both of its branches.
        function noisy_step(s, a)
            slipped = rand() < randomness
            random_action = sample(actions)
            return take_step(s, ifelse(slipped, random_action, a); dim, walls)
        end

        goal_reward(s0, a, s1) = ifelse(s1 == gw.goal, reward, - step_cost)
        at_goal(s) = s == gw.goal
        random_free_cell() = draw_not_wall(gw.walls)

        return ControlProblem(
            actions,
            uniform_prior,
            noisy_step,
            goal_reward,
            expected_next,
            at_goal,
            random_free_cell,
            γ)
    end
end
|
||
"""
    state_iterator(gw::Gridworld)

Return the cells of `gw` that are neither walls nor the goal.
"""
function state_iterator(gw::Gridworld)
    is_free_nongoal(cell) = !gw.walls[cell] & (cell != gw.goal)
    return filter(is_free_nongoal, CartesianIndices(gw.walls))
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
|
||
|
||
# Pure value iteration | ||
|
||
|
||
"""
    update_z!(z_dict, ctrl, v_dict; β = 1.0)

Perform one in-place sweep over all keys of `z_dict`, visiting them in random order.

For a terminal state `s0` the new value is `exp(0.0 - v_dict[s0])`; otherwise it is
`z_recursion` evaluated with the value function `s -> v_dict[s]`, the Q function derived
from it via `controlV`, and the current contents of `z_dict`. Because `z_dict` is
updated as the sweep proceeds, later states see the values already written in the
same sweep.

Returns the sum over states of the squared change of `log(z)`.
"""
function update_z!(z_dict, ctrl, v_dict; β = 1.0)
    V = s -> v_dict[s]
    Z = s -> z_dict[s]
    Q = (s, a) -> controlV(ctrl, V, s, a) # V-determined Q-function
    diff = 0.0 # running difference
    # `shuffle` needs an array, so the key set is collected first.
    for s0 in shuffle(collect(keys(z_dict)))
        new = if ctrl.terminal_condition(s0)
            exp(0.0 - v_dict[s0])
        else
            z_recursion(ctrl, Q, V, s0, Z; β)
        end
        diff += (log(new) - log(z_dict[s0]))^2
        z_dict[s0] = new
    end
    return diff
end

"""
    generate_z(ctrl, v_dict; β = 1.0, tol = 1e-5)

Iterate `update_z!` starting from `z = 1.0` on every key of `v_dict` until the value
returned by a sweep drops below `tol`, and return the resulting dictionary.

# Keywords
- `β`: forwarded to `update_z!`.
- `tol`: convergence threshold on the summed squared change of `log(z)` per sweep.
"""
function generate_z(ctrl, v_dict; β = 1.0, tol = 1e-5)
    z_dict = Dict(k => 1.0 for k in keys(v_dict))
    while true
        diff = update_z!(z_dict, ctrl, v_dict; β)
        if diff < tol
            break
        end
    end
    return z_dict
end
|
||
"""
    update_ν!(ν_dict, ctrl)

Perform one in-place sweep of the `truevalue_recursion` backup over all keys of `ν_dict`,
visiting them in random order. Terminal states are set to `0.0`. Because `ν_dict` is
updated as the sweep proceeds, later states see the values already written in the
same sweep.

Returns the sum over states of the squared change of `ν`. The change is measured on `ν`
directly: `truevalue_recursion` already returns a logarithm, and terminal states are
`0.0`, so taking another `log` would be undefined.
"""
function update_ν!(ν_dict, ctrl)
    ν = s -> ν_dict[s]
    diff = 0.0 # running difference
    # `shuffle` needs an array, so the key set is collected first.
    for s0 in shuffle(collect(keys(ν_dict)))
        new = ctrl.terminal_condition(s0) ? 0.0 : truevalue_recursion(ctrl, s0, ν)
        diff += (new - ν_dict[s0])^2
        ν_dict[s0] = new
    end
    return diff
end

"""
    generate_ν(ctrl, v_dict; tol = 1e-5)

Iterate `update_ν!` until the value returned by a sweep drops below `tol`, and return
the resulting dictionary. `v_dict` is only used for its keys (the set of states); every
state starts at `0.0`, the same value `update_ν!` assigns to terminal states.

# Keywords
- `tol`: convergence threshold on the summed squared change of `ν` per sweep.
"""
function generate_ν(ctrl, v_dict; tol = 1e-5)
    ν_dict = Dict(k => 0.0 for k in keys(v_dict))
    while true
        diff = update_ν!(ν_dict, ctrl)
        if diff < tol
            break
        end
    end
    return ν_dict
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,21 @@ | ||
using PIQL | ||
using Test | ||
|
||
# Fixtures shared by the tests below.
s = (20,20)
gw = make_gridworld(s; density = 0.1)
gwctrl = make_ctrl(gw; randomness = 0.1)
# `state_iterator` is not exported from PIQL, so it has to be qualified.
initial_value_dict = Dict(s => 0.0 for s in PIQL.state_iterator(gw))

bbctrl = binary_bandit_problem(rounds = 20, ncoins = 2, prior = (1,1), payoff = 1.0, γ = 1.0)

@testset "PIQL.jl" begin
    # Grid world: neither the goal nor a freshly drawn initial state may be a wall.
    @test gw.walls[gw.goal] == false
    @test gw.walls[gwctrl.initial_state()] == false

    # Bandit: a transition uses up exactly one round.
    i0 = bbctrl.initial_state()
    i1 = bbctrl.propagator(i0,1)
    @test i0.rounds - i1.rounds == 1
    @test average_reward(bbctrl,i0, 1) ≈ 1/2 # average reward is half payoff
    # After one observation the average reward is h/(h+t) of the updated belief.
    @test average_reward(bbctrl,i1, 1) ≈ i1.belief[1][1]/sum(i1.belief[1])
end