Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add actionvalues for ValueIterationPolicy #22

Merged
merged 1 commit into from
Oct 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/vanilla.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,12 @@ function value(policy::ValueIterationPolicy, s::S) where S
sidx = stateindex(policy.mdp, s)
policy.util[sidx]
end

"""
    POMDPPolicies.actionvalues(policy::ValueIterationPolicy, s::S) where S

Return the vector of Q-values (one entry per action) for state `s`, read from
the policy's stored Q matrix. Errors if the policy was built without the
Q function (`include_Q == false`).
"""
function POMDPPolicies.actionvalues(policy::ValueIterationPolicy, s::S) where S
    # Guard: the Q matrix is only populated when include_Q was requested.
    policy.include_Q || error("ValueIterationPolicyError: the policy does not contain the Q function!")
    return policy.qmat[stateindex(policy.mdp, s), :]
end
43 changes: 22 additions & 21 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
using DiscreteValueIteration
using POMDPModels
using POMDPModelTools
using POMDPs
using Test

# Test basic value iteration functionality
@testset "all" begin
@testset "policy" begin
include("test_value_iteration_policy.jl") # test the policy object
end
@testset "basic" begin
include("test_basic_value_iteration.jl") # then the creation of a policy
end
@testset "basic disallowing actions" begin
include("test_basic_value_iteration_disallowing_actions.jl") # then a complex form where states determine actions
end

println("Testing Requirements")
@requirements_info ValueIterationSolver() LegacyGridWorld()
end
using DiscreteValueIteration
using POMDPModels
using POMDPModelTools
using POMDPPolicies
using POMDPs
using Test

# Test basic value iteration functionality
@testset "all" begin
# Each sub-suite lives in its own file; ordering goes from the policy object
# itself to full solver runs.
@testset "policy" begin
include("test_value_iteration_policy.jl") # test the policy object
end
@testset "basic" begin
include("test_basic_value_iteration.jl") # then the creation of a policy
end
@testset "basic disallowing actions" begin
include("test_basic_value_iteration_disallowing_actions.jl") # then a complex form where states determine actions
end

# Check the solver's declared POMDPs.jl interface requirements against the
# LegacyGridWorld model (prints an info report).
println("Testing Requirements")
@requirements_info ValueIterationSolver() LegacyGridWorld()
end
194 changes: 97 additions & 97 deletions test/test_basic_value_iteration.jl
Original file line number Diff line number Diff line change
@@ -1,98 +1,98 @@
using DelimitedFiles

function support_serial_qtest(mdp::Union{MDP,POMDP}, file::AbstractString; niter::Int64=100, res::Float64=1e-3)
qt = readdlm(file)
solver = ValueIterationSolver(max_iterations=niter, belres=res, verbose=true)
policy = solve(solver, mdp)
(q, u, p, am) = locals(policy)
npolicy = ValueIterationPolicy(mdp, deepcopy(q))
nnpolicy = ValueIterationPolicy(mdp, deepcopy(q), deepcopy(u), deepcopy(p))
s = GridWorldState(1,1)
a1 = action(policy, s)
v1 = value(policy, s)
a2 = action(npolicy, s)
v2 = value(npolicy, s)
return (isapprox(qt, q, rtol=1e-5)) && (policy.policy == nnpolicy.policy)
end


function test_complex_gridworld()
# Load correct policy from file and verify we can reconstruct it
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = LegacyGridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
file = "grid-world-10x10-Q-matrix.txt"
niter = 100
res = 1e-3

return support_serial_qtest(mdp, file, niter=niter, res=res)
end

function test_simple_grid()
# Simple test....
# GridWorld(sx=2,sy=3) w reward at (2,3):
# Here's our grid:
# |state (x,y)____available actions__|
# ----------------------------------------------
# |5 (1,3)__u,d,l,r__|6 (2,3)__u,d,l,r+REWARD__|
# |3 (1,2)__u,d,l,r__|4 (2,2)______u,d,l,r_____|
# |1 (1,1)__u,d,l,r__|2 (2,1)______u,d,l,r_____|
# ----------------------------------------------
# 7 (0,0) is absorbing state
mdp = LegacyGridWorld(sx=2, sy=3, rs = [GridWorldState(2,3)], rv = [10.0])

solver = ValueIterationSolver(verbose=true)
policy = solve(solver, mdp)

# up: 1, down: 2, left: 3, right: 4
correct_policy = [1,1,1,1,4,1,1] # alternative policies
# are possible, but since they are tied & the first
# action is always 1, we will always return 1 for tied
# actions
return policy.policy == correct_policy
end

function test_init_solution()
# Initialize the value to the solution
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = LegacyGridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
qt = readdlm("grid-world-10x10-Q-matrix.txt")
ut = maximum(qt, dims=2)[:]
solver = ValueIterationSolver(verbose=true, init_util=ut, belres=1e-3)
policy = solve(solver, mdp)
return isapprox(ut, policy.util, rtol=1e-5)
end

function test_not_include_Q()
# Load correct policy from file and verify we can reconstruct it
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = LegacyGridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
qt = readdlm("grid-world-10x10-Q-matrix.txt")
ut = maximum(qt, dims=2)[:]
niter = 100
res = 1e-3
solver = ValueIterationSolver(verbose=true, init_util=ut, belres=1e-3, include_Q=false)
policy = solve(solver, mdp)
return isapprox(ut, policy.util, rtol=1e-3)
end

function test_warning()
mdp = LegacyGridWorld()
solver = ValueIterationSolver()
println("There should be a warning bellow: ")
solve(solver, mdp, verbose=true)
end

@test test_complex_gridworld() == true
@test test_simple_grid() == true
@test test_init_solution() == true
@test test_not_include_Q() == true
using DelimitedFiles
"""
    support_serial_qtest(mdp, file; niter=100, res=1e-3)

Solve `mdp` with value iteration and compare the resulting Q matrix against
the reference matrix stored in `file`. Also rebuilds policies from the solver
output (Q-only and Q+util+policy) and checks the rebuilt policy matches.
"""
function support_serial_qtest(mdp::Union{MDP,POMDP}, file::AbstractString; niter::Int64=100, res::Float64=1e-3)
    reference_q = readdlm(file)
    policy = solve(ValueIterationSolver(max_iterations=niter, belres=res, verbose=true), mdp)
    (qmat, util, pol, amap) = locals(policy)
    q_only = ValueIterationPolicy(mdp, deepcopy(qmat))
    rebuilt = ValueIterationPolicy(mdp, deepcopy(qmat), deepcopy(util), deepcopy(pol))
    # Smoke-test the query functions on both the solved and rebuilt policies.
    probe = GridWorldState(1, 1)
    action(policy, probe)
    value(policy, probe)
    action(q_only, probe)
    value(q_only, probe)
    return isapprox(reference_q, qmat, rtol=1e-5) && (policy.policy == rebuilt.policy)
end
# Solve a 10x10 gridworld with several reward states and verify the Q matrix
# matches the reference matrix stored on disk.
function test_complex_gridworld()
    reward_states = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
    reward_values = [-10.0, -5.0, 10.0, 3.0]
    mdp = LegacyGridWorld(sx=10, sy=10, rs=reward_states, rv=reward_values)
    return support_serial_qtest(mdp, "grid-world-10x10-Q-matrix.txt", niter=100, res=1e-3)
end
# Solve a tiny 2x3 gridworld with a single reward at (2,3) and compare the
# result to the hand-computed optimal policy.
#
# State layout (x,y); all four moves are available everywhere:
#   |5 (1,3)|6 (2,3)+REWARD|
#   |3 (1,2)|4 (2,2)       |
#   |1 (1,1)|2 (2,1)       |
# State 7 is the absorbing state.
function test_simple_grid()
    mdp = LegacyGridWorld(sx=2, sy=3, rs=[GridWorldState(2,3)], rv=[10.0])
    policy = solve(ValueIterationSolver(verbose=true), mdp)
    # Action indices: up=1, down=2, left=3, right=4. Other optimal policies
    # exist, but ties always break toward the first action (index 1), so the
    # expected policy is deterministic.
    expected = [1, 1, 1, 1, 4, 1, 1]
    return policy.policy == expected
end
# Warm-start the solver with the known optimal utility (derived from the
# reference Q matrix) and check the converged utility is unchanged.
function test_init_solution()
    reward_states = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
    reward_values = [-10.0, -5.0, 10.0, 3.0]
    mdp = LegacyGridWorld(sx=10, sy=10, rs=reward_states, rv=reward_values)
    qref = readdlm("grid-world-10x10-Q-matrix.txt")
    uref = maximum(qref, dims=2)[:]  # optimal utility = rowwise max of Q
    policy = solve(ValueIterationSolver(verbose=true, init_util=uref, belres=1e-3), mdp)
    return isapprox(uref, policy.util, rtol=1e-5)
end
# Solve with include_Q=false (no Q matrix stored in the policy), starting from
# the known optimal utility, and check the utility still matches.
#
# Fix: the original defined `niter` and `res` locals but never used them —
# `niter` was dead and `belres`/`rtol` hard-coded 1e-3 instead of using `res`.
# The tolerance is now threaded through a single variable (same value, same
# behavior).
function test_not_include_Q()
    rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
    rvals = [-10.0, -5.0, 10.0, 3.0]
    mdp = LegacyGridWorld(sx=10, sy=10, rs = rstates, rv = rvals)
    qt = readdlm("grid-world-10x10-Q-matrix.txt")
    ut = maximum(qt, dims=2)[:]  # optimal utility = rowwise max of Q
    res = 1e-3  # Bellman residual tolerance, reused as the comparison rtol
    solver = ValueIterationSolver(verbose=true, init_util=ut, belres=res, include_Q=false)
    policy = solve(solver, mdp)
    return isapprox(ut, policy.util, rtol=res)
end
# Smoke test: calling `solve` with a `verbose` keyword should emit a warning
# (NOTE(review): the warning itself is produced inside the solver — confirm).
# No assertion is made; the warning is checked by eye in the test output.
#
# Fix: corrected the typo "bellow" -> "below" in the printed message.
function test_warning()
    mdp = LegacyGridWorld()
    solver = ValueIterationSolver()
    println("There should be a warning below: ")
    solve(solver, mdp, verbose=true)
end
# Run each scenario; every test function above returns a Bool.
@test test_complex_gridworld() == true
@test test_simple_grid() == true
@test test_init_solution() == true
@test test_not_include_Q() == true
test_warning()  # exercised for its warning side effect only; no return checked
134 changes: 67 additions & 67 deletions test/test_basic_value_iteration_disallowing_actions.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,67 @@
mutable struct SpecialGridWorld <: MDP{GridWorldState, GridWorldAction}
gw::LegacyGridWorld
end

POMDPs.discount(g::SpecialGridWorld) = discount(g.gw)
POMDPs.n_states(g::SpecialGridWorld) = n_states(g.gw)
POMDPs.n_actions(g::SpecialGridWorld) = n_actions(g.gw)
POMDPs.transition(g::SpecialGridWorld, s::GridWorldState, a::GridWorldAction) = transition(g.gw, s, a)
POMDPs.reward(g::SpecialGridWorld, s::GridWorldState, a::GridWorldAction, sp::GridWorldState) = reward(g.gw, s, a, sp)
POMDPs.stateindex(g::SpecialGridWorld, s::GridWorldState) = stateindex(g.gw, s)
POMDPs.actionindex(g::SpecialGridWorld, a::GridWorldAction) = actionindex(g.gw, a)
POMDPs.actions(g::SpecialGridWorld, s::GridWorldState) = actions(g.gw, s)
POMDPs.states(g::SpecialGridWorld) = states(g.gw)
POMDPs.actions(g::SpecialGridWorld) = actions(g.gw)

# Let's extend actions to hard-code & limit to the
# particular feasible actions from each state....
function POMDPs.actions(mdp::SpecialGridWorld, s::GridWorldState)
# up: 1, down: 2, left: 3, right: 4
sidx = stateindex(mdp, s)
if sidx == 1
acts = [GridWorldAction(:left), GridWorldAction(:right)]
elseif sidx == 2
acts = [GridWorldAction(:up), GridWorldAction(:right)]
elseif sidx == 3
acts = [GridWorldAction(:left), GridWorldAction(:up)]
elseif sidx == 4
acts = [GridWorldAction(:left), GridWorldAction(:right)]
elseif sidx == 5
acts = [GridWorldAction(:left), GridWorldAction(:right)]
elseif sidx == 6
acts = [GridWorldAction(:up), GridWorldAction(:down), GridWorldAction(:left), GridWorldAction(:right)]
elseif sidx == 7
acts = [GridWorldAction(:up), GridWorldAction(:down), GridWorldAction(:left), GridWorldAction(:right)]
end
return acts
end


function test_conditioning_actions_on_state()
# Condition available actions on next state....
# GridWorld(sx=2,sy=3) w reward at (2,3):
#
# Here's our grid, with some actions missing:
#
# |state (x,y)____available actions__|
# |5 (1,3)__l,r__|6 (2,3)__u,d,l,r+REWARD__|
# |3 (1,2)__l,u__|4 (2,2)_______l,r________|
# |1 (1,1)__l,r__|2 (2,1)_______u,r________|
# 7 (0,0) is absorbing state
mdp = SpecialGridWorld(LegacyGridWorld(sx=2, sy=3, rs = [GridWorldState(2,3)], rv = [10.0]))

solver = ValueIterationSolver(verbose=true)
policy = solve(solver, mdp)

println(policy.policy)

# up: 1, down: 2, left: 3, right: 4
correct_policy = [4,1,1,3,4,1,1] # alternative policies
# for state 6 are possible, but since they are ordered
# such that 1 comes first, 1 will always be the policy
return policy.policy == correct_policy
end

@test test_conditioning_actions_on_state() == true

println("Finished tests")
# Wrapper MDP used to override `actions` per state while delegating every
# other interface function to the wrapped LegacyGridWorld.
mutable struct SpecialGridWorld <: MDP{GridWorldState, GridWorldAction}
    gw::LegacyGridWorld
end
# Forward the standard POMDPs.jl interface to the wrapped gridworld.
POMDPs.discount(g::SpecialGridWorld) = discount(g.gw)
POMDPs.n_states(g::SpecialGridWorld) = n_states(g.gw)
POMDPs.n_actions(g::SpecialGridWorld) = n_actions(g.gw)
POMDPs.transition(g::SpecialGridWorld, s::GridWorldState, a::GridWorldAction) = transition(g.gw, s, a)
POMDPs.reward(g::SpecialGridWorld, s::GridWorldState, a::GridWorldAction, sp::GridWorldState) = reward(g.gw, s, a, sp)
POMDPs.stateindex(g::SpecialGridWorld, s::GridWorldState) = stateindex(g.gw, s)
POMDPs.actionindex(g::SpecialGridWorld, a::GridWorldAction) = actionindex(g.gw, a)
POMDPs.actions(g::SpecialGridWorld, s::GridWorldState) = actions(g.gw, s)
POMDPs.states(g::SpecialGridWorld) = states(g.gw)
POMDPs.actions(g::SpecialGridWorld) = actions(g.gw)
# Let's extend actions to hard-code & limit to the
# particular feasible actions from each state....
# Hard-coded per-state action restriction for the test gridworld
# (action indices: up=1, down=2, left=3, right=4).
function POMDPs.actions(mdp::SpecialGridWorld, s::GridWorldState)
    sidx = stateindex(mdp, s)
    up = GridWorldAction(:up)
    down = GridWorldAction(:down)
    left = GridWorldAction(:left)
    right = GridWorldAction(:right)
    # States 1, 4 and 5 share the same action set, as do 6 and 7; states
    # outside 1:7 are not expected here.
    if sidx == 2
        acts = [up, right]
    elseif sidx == 3
        acts = [left, up]
    elseif sidx == 1 || sidx == 4 || sidx == 5
        acts = [left, right]
    elseif sidx == 6 || sidx == 7
        acts = [up, down, left, right]
    end
    return acts
end
# Solve the restricted-action 2x3 gridworld and compare against the policy
# worked out by hand.
#
# Grid (available actions per state):
#   |5 (1,3) l,r|6 (2,3) u,d,l,r + REWARD|
#   |3 (1,2) l,u|4 (2,2) l,r            |
#   |1 (1,1) l,r|2 (2,1) u,r            |
# State 7 is the absorbing state.
function test_conditioning_actions_on_state()
    mdp = SpecialGridWorld(LegacyGridWorld(sx=2, sy=3, rs=[GridWorldState(2,3)], rv=[10.0]))
    policy = solve(ValueIterationSolver(verbose=true), mdp)
    println(policy.policy)
    # Action indices: up=1, down=2, left=3, right=4. Alternatives exist for
    # state 6, but action ordering makes 1 (up) the deterministic choice.
    expected = [4, 1, 1, 3, 4, 1, 1]
    return policy.policy == expected
end
# Run the scenario; the test function returns a Bool.
@test test_conditioning_actions_on_state() == true
println("Finished tests")
Loading