
Commit

focus on tabular case
chelate committed Dec 14, 2023
1 parent 5ab4dc8 commit 200bf21
Showing 6 changed files with 348 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Project.toml
@@ -3,6 +3,10 @@ uuid = "bbcca2e2-8668-413d-ba81-c82c0d64d1df"
authors = ["chelate <42802644+chelate@users.noreply.github.com> and contributors"]
version = "1.0.0-DEV"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.7"

87 changes: 87 additions & 0 deletions src/PIQL.jl
@@ -1,5 +1,92 @@
module PIQL
export ControlProblem, average_reward



"""
ControlProblem is a struct whose fields hold
the functions that completely define a KL-control problem.
"""
struct ControlProblem{AA, U, P, R, PA, T, W}
action_space::AA # something that we can iterate over
action_prior::U # π(s,a) -> Float64 exactly like energy, assumed to be normalized
propagator::P # p(x0, a) -> x1 ("random" state)
reward_function::R # r(x0, a, x1) -> reward ::Float64
# given in entropic units already
propagator_average::PA # (s,a,f) -> K·f
terminal_condition::T # T(x) -> Bool
initial_state::W # W() -> x0 generates initial states of interest
γ::Float64 # discount factor, a positive number less than one
end
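
For orientation, a minimal sketch of an instance of this interface (a hypothetical one-step, two-action toy problem, not part of this commit; the real constructors live in src/Problems/ below):

toy_problem = ControlProblem(
    1:2,                      # action_space: two actions
    (s, a) -> 0.5,            # action_prior: uniform over the two actions
    (s, a) -> s + a,          # propagator: deterministic next state
    (s0, a, s1) -> float(a),  # reward_function: reward equals the action taken
    (s, a, f) -> f(s + a),    # propagator_average: expectation is trivial for a deterministic step
    s -> s > 0,               # terminal_condition: stop after the first move
    () -> 0,                  # initial_state: always start at 0
    0.9)                      # γ: discount factor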


"""
The reward averaged over next states: E_{s1∼p(·|s,a)}[ r(s,a,s1) ].
"""
function average_reward(ctrl,s,a)
ctrl.propagator_average(s,a, s1 -> ctrl.reward_function(s,a,s1))
end


"""
generate `controlv` function.
"""
function controlV(ctrl, V, s, a)
ctrl.γ*ctrl.propagator_average(s,a,V) - value_function(s) +
average_reward(ctrl,s,a)
end

"""
the normalization for an unnormalized q function, function of state
"""
function Qnormalization(ctrl, Q, s)
z = 0.0
for a in ctrl.action_space
z += ctrl.action_prior(s,a)*exp(Q(s, a))
end
return log(z)
end

"""
"""
function controlQ(ctrl, Q, s, a)
Q(s,a) - Qnormalization(ctrl, Q, s)
end

function action_probability(ctrl, Q, s, a)
ctrl.action_prior(s,a) * exp(controlQ(ctrl,Q,s,a))
end

function fitness(ctrl,V, Q, s, a)
controlV(ctrl, V, s, a) - controlQ(ctrl, Q, s, a)
end

"""
this is used to generate the optimal control directly, it is the backup operator fo rhte bellman equation
see eq. 7 in paper
"""
function truevalue_recursion(ctrl, s, ν)
z = 0.0
for a in ctrl.action_space
z += ctrl.action_prior(s,a)*exp(ctrl.γ * ctrl.propagator_average(s, a, ν) + average_reward(ctrl, s, a))
end
return log(z)
end
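
Reading the loop off directly, in the notation of the field comments at the top of this file, the recursion implements the log-domain backup (a transcription of the code above, not a quotation of eq. 7):

ν(s) = log Σ_a π(s,a) · exp( r̄(s,a) + γ · E_{s'∼p(·|s,a)}[ ν(s') ] )

where r̄(s,a) is the averaged reward computed by `average_reward`.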

"""
This is used to determine Z.
"""
function z_recursion(ctrl, Q, V, s, Z; β= 1.0)
out = 0.0
for a in ctrl.action_space
out += action_probability(ctrl, Q, s, a) * exp(β * fitness(ctrl, V, Q, s, a)) * ctrl.propagator_average(s, a, Z)
end
return out
end
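
Written out the same way, each term of this sum is the controlled action probability times an exponentially tilted fitness times the propagated Z (again a transcription of the code):

Z(s) = Σ_a p_ctrl(a|s) · exp( β · fitness(s,a) ) · E_{s'∼p(·|s,a)}[ Z(s') ]

with p_ctrl(a|s) = π(s,a) · exp(controlQ(s,a)) from `action_probability`, and fitness = controlV − controlQ.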

include("value_iteration.jl")
include("Problems/BinaryBandit.jl")
include("Problems/GridWorld.jl")

end
68 changes: 68 additions & 0 deletions src/Problems/BinaryBandit.jl
@@ -0,0 +1,68 @@
export binary_bandit_problem

using Distributions

# the state structure is
# state = BanditState( [(h1,t1), ..., (hn,tn)], rounds_left )

struct BanditState{I<:AbstractVector,R<:Integer}
belief::I
rounds::R
end
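
# For example (hypothetical values): BanditState([(1, 1), (1, 1)], 10) describes two coins,
# each with a uniform Beta(1, 1) belief stored as (heads, tails) counts, and 10 rounds remaining.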

function observe(bs,a)
(h,t) = bs.belief[a]
rand(Bernoulli(h/(h+t)))
end

function update(bs,a,o)
belief = copy(bs.belief)
(h,t) = belief[a]
belief[a] = (h + o, t + !o)
BanditState(belief, bs.rounds -1)
end

function reachable(bs)
[update(bs,a,o) for a in 1:length(bs.belief) for o in [true,false]]
end


function binary_bandit_problem(;rounds = 10, ncoins = 2, prior = (1,1), payoff = 1.0, γ = 1.0)
action_space = 1:ncoins
action_prior(s,a) = 1.0 / ncoins
function propagator(s0,a)
o = observe(s0, a)
return update(s0, a, o)
end
function reward(s0, a, s1)
(s1.belief[a][1] - s0.belief[a][1])*payoff
end
function propagator_average(s,a,f)
(h,t) = s.belief[a]
f(update(s, a, false))*(t/(h+t)) + f(update(s, a, true))*(h/(h+t))
end
term(bs) = bs.rounds < 1
initial_state() = BanditState([prior for ii in 1:ncoins],rounds)
ControlProblem(
action_space,
action_prior,
propagator,
reward,
propagator_average,
term, initial_state, γ)
end


function state_iterator(;rounds = 10, ncoins = 2, prior = (1,1))
out = [BanditState([prior for ii in 1:ncoins],rounds)]
cur = out
for _ in 1:rounds-1
next = similar(cur,0)
for state in cur
push!(next, reachable(state)...)
end
out = vcat(out,next)
cur = next
end
return Iterators.reverse(out)
end
119 changes: 119 additions & 0 deletions src/Problems/GridWorld.jl
@@ -0,0 +1,119 @@

export make_gridworld, make_ctrl
# struct ControlProblem{A, U, P, R, PA, T, W}
# action_space::Vector{A} # something that we can iterate over
# action_prior::U # π(s,a) -> Float64 exactly like energy
# propagator::P # p(x0, a) -> x1 ("random" state)
# reward_function::R # r(x0, a, x1) -> reward ::Float64
# # given in entropic units already
# propagator_average::PA # (s,a,f) -> K·f
# terminal_condition::T # T(x) -> bdol
# initial_state::W # W() -> x0 generates inital states of interest
# γ::Float64 # positive number less than one discount over time
# end

using StatsBase

"""
Actions are indexed by natural numbers, 0-3 in 2D
"""
function step_choices(; dim::Val{G} = Val(2)) where G
0:2*G-1
end

function step(a; dim::Val{G} = Val(2)) where G
CartesianIndex(ntuple(ii -> (1-2*mod(a,2))*(div(a,2) == ii-1) , G))
end
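
# For reference (assuming the default dim = Val(2)), the indexing works out to:
# step(0) == CartesianIndex(1, 0)    step(1) == CartesianIndex(-1, 0)
# step(2) == CartesianIndex(0, 1)    step(3) == CartesianIndex(0, -1)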


function take_step(x0::G, a; dim = Val(2), walls = Bool[]) where G
x1::G = x0 + step(a;dim)
return ifelse(walls[x1], x0, x1)
end


struct Gridworld
walls::Array{Bool}
goal::CartesianIndex
end

function make_walls(size; density = 0.1)
walls = falses((size .+ 2)...)
walls[begin,:] .= true
walls[end,:] .= true
walls[:,begin] .= true
walls[:,end] .= true
for ii in eachindex(walls)
if rand() < density
walls[ii] = true # put down random barriers
end
end
return walls
end

function draw_not_wall(walls)
while true
ii = rand(LinearIndices(walls))
if !walls[ii]
return CartesianIndices(walls)[ii]
end
end
end

function get_reachable(walls, goal)
reachable_set = Set([goal])
boundary = Set([goal])
dim = Val(ndims(walls))
while !isempty(boundary)
union!(reachable_set, boundary)
boundary = reduce(union,
Set(ii + s
for s in step.(step_choices(; dim); dim) if
!in(ii + s, reachable_set) & !walls[(ii + s)] )
for ii in boundary)
end
return reachable_set
end

function make_gridworld(size; density = 0.1)
walls = make_walls(size; density)
goal = draw_not_wall(walls)
reachable_set = get_reachable(walls, goal)
for ii in CartesianIndices(walls)
if !in(ii,reachable_set)
walls[ii] = true
end
end
return Gridworld(walls,goal)
end

function make_ctrl(gw::Gridworld;
randomness = 0.1, # likelihood of choosing a random action
reward_scale = 0.2, # temperature
step_cost = 0.1*reward_scale, # entropic cost of a step
reward = 1*reward_scale, # reward of reaching the end
γ = 0.99) # discounting rate
let dim = Val(ndims(gw.walls)), len = length(step_choices(;dim)), walls = gw.walls
# attempt to make the closure fast
function pa(s,a,f) # propagator average
out::Float64 = 0.0
for aa in step_choices(;dim)
out += ((a == aa)*(1-randomness) + (randomness / len)) * f(take_step(s,aa; dim, walls))
end
out
end
ControlProblem(
step_choices(;dim),
(s,a)->1.0 / len,
(s,a) -> take_step(s,
ifelse(rand() < randomness, sample(step_choices(;dim)),a);
dim, walls),
(s0,a,s1) -> ifelse(s1 == gw.goal, reward, - step_cost),
pa,
s -> s == gw.goal ,
() -> draw_not_wall(gw.walls),
γ)
end
end

state_iterator(gw::Gridworld) = filter(ii -> !gw.walls[ii] & (ii != gw.goal),
CartesianIndices(gw.walls))
54 changes: 54 additions & 0 deletions src/value_iteration.jl
@@ -0,0 +1,54 @@


using Random: shuffle # shuffle is used for the randomized sweep order below

# Pure value iteration


function update_z!(z_dict, ctrl, v_dict; β = 1.0)
diff = 0.0 # running squared difference
for s0 in shuffle(collect(keys(z_dict)))
if ctrl.terminal_condition(s0)
new = exp(0.0 - v_dict[s0])
else
new = z_recursion(ctrl,
(s,a) -> controlV(ctrl, s -> v_dict[s], s, a), # V-determined Q-function
s -> v_dict[s], s0, s -> z_dict[s]; β)
end
diff += (log(new) - log(z_dict[s0]))^2
z_dict[s0] = new
end
return diff
end

function generate_z(ctrl, v_dict)
z_dict = Dict(k => 1.0 for k in keys(v_dict))
while true
diff = update_z!(z_dict, ctrl, v_dict)
if diff < 10^(-5)
break
end
end
return z_dict
end

function update_ν!(ν_dict, ctrl)
diff = 0.0 # running squared difference
for s0 in shuffle(collect(keys(ν_dict)))
if ctrl.terminal_condition(s0)
new = 0.0
else
new = truevalue_recursion(ctrl, s0, s -> ν_dict[s])
end
diff += (new - ν_dict[s0])^2 # ν is already a log-value, so compare directly
ν_dict[s0] = new
end
return diff
end

function generate_ν(ctrl, v_dict)
ν_dict = Dict(k => 1.0 for k in keys(v_dict))
while true
diff = update_ν!(ν_dict, ctrl)
if diff < 10^(-5)
break
end
end
return ν_dict
end
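
A minimal usage sketch for the ν-recursion (illustrative, not part of the commit; the gridworld parameters follow test/runtests.jl, and the value dictionary here keys every non-wall cell, goal included, so that terminal lookups succeed):

gw = make_gridworld((20, 20); density = 0.1)
gwctrl = make_ctrl(gw; randomness = 0.1)
ν0 = Dict(s => 0.0 for s in CartesianIndices(gw.walls) if !gw.walls[s]) # all non-wall cells
ν = PIQL.generate_ν(gwctrl, ν0) # converged log-domain state values, keyed by grid cell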
17 changes: 16 additions & 1 deletion test/runtests.jl
@@ -1,6 +1,21 @@
using PIQL
using Test

s = (20,20)
gw = make_gridworld(s; density = 0.1)
gwctrl = make_ctrl(gw; randomness = 0.1)
initial_value_dict = Dict(s => 0.0 for s in PIQL.state_iterator(gw))


bbctrl = binary_bandit_problem(rounds = 20, ncoins = 2, prior = (1,1), payoff = 1.0, γ = 1.0)

@testset "PIQL.jl" begin
# Write your tests here.
@test gw.walls[gw.goal] == false
@test gw.walls[gwctrl.initial_state()] == false

i0 = bbctrl.initial_state()
i1 = bbctrl.propagator(i0,1)
@test i0.rounds - i1.rounds == 1
@test average_reward(bbctrl,i0, 1) == 1/2 # average reward is half payoff
@test average_reward(bbctrl,i1, 1) == i1.belief[1][1]/sum(i1.belief[1]) # average reward is the posterior mean of coin 1
end
