major clean up!
FlyingWorkshop committed Mar 23, 2024
1 parent d665d59 commit f733630
Showing 11 changed files with 135 additions and 285 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Logan-Mondal-Bhamidipaty"]
version = "1.0.0-DEV"

[deps]
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
50 changes: 0 additions & 50 deletions images/baby_benchmark.svg

This file was deleted.

4 changes: 1 addition & 3 deletions paper.md
@@ -65,15 +65,13 @@ v = value(approx_policy, s)
a = action(approx_policy, s)
```

![We see that that the compressed solver recovers a similar value function to SARSOP [@SARSOP].](./images/baby_benchmark.svg){height="200pt"}

## Function Approximators

CompressedBeliefMDPs.jl is compatible with any `LocalFunctionApproximator`. It supports grid interpolations [@grid] through [GridInterpolations.jl](https://github.com/sisl/GridInterpolations.jl) and $k$-nearest neighbors [@kNN] through [NearestNeighbors.jl](https://github.com/KristofferC/NearestNeighbors.jl). For more details, see [LocalFunctionApproximation.jl](https://github.com/sisl/LocalFunctionApproximation.jl)

## Compressors

CompressedBeliefMDPs.jl provides several wrappers for commonly used compressors. Through [MultiVariateStats.jl](https://juliastats.org/MultivariateStats.jl/stable/), we include PCA [@PCA], kernel PCA [@kernelPCA], and probabilistic PCA [@PPCA]. Through a forthcoming package `ExpFamilyPCA.jl`, we also include exponential family PCA [@EPCA]. This is more general than the Poisson exponential family PCA in @Roy: ExpFamilyPCA.jl supports compression objectives induced from _any_ convex function, not just the Poisson link function.
CompressedBeliefMDPs.jl provides several wrappers for commonly used compressors. Through [MultiVariateStats.jl](https://juliastats.org/MultivariateStats.jl/stable/), we include PCA [@PCA], kernel PCA [@kernelPCA], and probabilistic PCA [@PPCA].
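
Reviewer's note: as a rough picture of what a PCA-style compressor does to a belief, here is a hand-rolled sketch built on `LinearAlgebra.svd` rather than on the package's actual MultivariateStats.jl wrappers. The names `compress_pca` and `decompress_pca`, and the toy belief matrix, are assumptions made for illustration only. The final step mirrors the `normalize(abs.(b), 1)` repair that the removed `decode` in `src/cbmdp.jl` applied, since a reconstructed vector is not guaranteed to be a valid probability distribution.

```julia
using LinearAlgebra

# Toy data (assumption): each row is a belief over 3 states and sums to 1.
beliefs = [0.8 0.1 0.1;
           0.1 0.8 0.1;
           0.1 0.1 0.8;
           0.4 0.3 0.3]

μ = vec(sum(beliefs; dims=1) ./ size(beliefs, 1))  # mean belief (length 3)
F = svd(beliefs .- μ')                             # columns of F.V are principal directions
W = F.V[:, 1:1]                                    # keep a single component (3×1)

# Hypothetical helpers, not the package API.
compress_pca(b) = W' * (b .- μ)       # length-3 belief -> length-1 code
decompress_pca(c) = W * c .+ μ        # length-1 code   -> length-3 reconstruction

b = [0.1, 0.8, 0.1]
b_rec = decompress_pca(compress_pca(b))   # lossy; entries need not be a distribution
b_valid = normalize(abs.(b_rec), 1)       # force nonnegative entries that sum to 1
```

The reconstruction is lossy by design, which is one reason the new code caches the original belief instead of decompressing it.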

# Acknowledgements

19 changes: 9 additions & 10 deletions src/CompressedBeliefMDPs.jl
@@ -4,8 +4,13 @@ using Infiltrator

using POMDPs
using POMDPTools
using LocalApproximationValueIteration
using LocalFunctionApproximation

import Lazy: @forward
using Bijections
using NearestNeighbors
using StaticArrays

using LinearAlgebra
using Random
@@ -27,22 +32,16 @@ include("compressors/compressor.jl")
include("compressors/mv_stats.jl")

export
Sampler,
sample,
# DiscreteEpsGreedySampler,
# DiscreteSoftmaxSampler,
DiscreteRandomSampler
include("samplers/sampler.jl")
include("samplers/utils.jl")
sample
include("sampler.jl")

export
CompressedBeliefMDP,
CompressedBeliefMDPState
include("cbmdp.jl")

export
CompressedSolver,
CompressedSolverPolicy
CompressedBeliefSolver,
CompressedBeliefPolicy,
solve
include("solver.jl")

182 changes: 76 additions & 106 deletions src/cbmdp.jl
@@ -1,134 +1,104 @@
# TODO: add particle filters --> AbstractParticleBelief https://juliapomdp.github.io/ParticleFilters.jl/latest/beliefs/


"""
struct CompressedBeliefMDP{B, A} <: MDP{B, A}
A struct representing a Markov Decision Process (MDP) with compressed beliefs.
# Fields
- `bmdp::GenerativeBeliefMDP`: The underlying GenerativeBeliefMDP that defines the original MDP.
- `compressor::Compressor`: The compressor used to compress beliefs in the MDP.
## Type Parameters
- `B`: Type of the belief space.
- `A`: Type of the action space.
## Example
```julia
pomdp = TigerPOMDP()
updater = DiscreteUpdater(pomdp)
compressor = PCACompressor(1)
generative_mdp = GenerativeBeliefMDP(pomdp, updater)
mdp = CompressedBeliefMDP(generative_mdp, compressor)
```
"""
struct CompressedBeliefMDP{B, A} <: MDP{B, A}
bmdp::GenerativeBeliefMDP
compressor::Compressor
ϕ::Bijection # ϕ: belief ↦ compress(compressor, belief); NOTE: While compressions aren't usually injective, we cache compressed beliefs on a first-come, first-served basis, so the *cache* is effectively bijective.
end

"""
struct CompressedBeliefMDPState
A struct representing the state of a CompressedBeliefMDP.
# Fields
- `b̃::AbstractArray{<:Real}`: Compressed belief vector.
"""
struct CompressedBeliefMDPState
b̃::AbstractArray{<:Real}
end

# TODO: merge docstring of constructor w/ docstring for the struct proper
"""
CompressedBeliefMDP(pomdp::POMDP, updater::Updater, compressor::Compressor)
Create a CompressedBeliefMDP based on a given POMDP, updater, and compressor.
This function initializes a `GenerativeBeliefMDP` using the provided `pomdp` and `updater`.
It then constructs a `CompressedBeliefMDP` with the specified `compressor`.
# Arguments
- `pomdp::POMDP`: The original partially observable Markov decision process (POMDP).
- `updater::Updater`: The belief updater used in the GenerativeBeliefMDP.
- `compressor::Compressor`: The compressor used to compress beliefs in the MDP.
# Returns
- `CompressedBeliefMDP{CompressedBeliefMDPState, actiontype(bmdp)}`: The compressed belief MDP.
## Example
```julia
pomdp = MyPOMDP() # replace with your specific POMDP type
updater = MyUpdater() # replace with your specific updater type
compressor = MyCompressor() # replace with your specific compressor type
mdp = CompressedBeliefMDP(pomdp, updater, compressor)
```
"""
function CompressedBeliefMDP(pomdp::POMDP, updater::Updater, compressor::Compressor)
# Hack to determine typeof(b̃)
bmdp = GenerativeBeliefMDP(pomdp, updater)
return CompressedBeliefMDP{CompressedBeliefMDPState, actiontype(bmdp)}(bmdp, compressor)
b = initialstate(bmdp).val
b̃ = compress(compressor, convert_s(AbstractVector{Float64}, b, bmdp.pomdp))
B = typeof(b)
B̃ = typeof(b̃)
ϕ = Bijection{B, B̃}()
return CompressedBeliefMDP{B̃, actiontype(bmdp)}(bmdp, compressor, ϕ)
end


function decode(m::CompressedBeliefMDP, s::CompressedBeliefMDPState)
b = decompress(m.compressor, s.b̃)
b = normalize(abs.(b), 1) # make valid probability distribution
b = convert_s(statetype(m.bmdp), b, m.bmdp.pomdp)
function decode(m::CompressedBeliefMDP, b̃)
b = m.ϕ(b̃)
return b
end

function encode(m::CompressedBeliefMDP, b)
b = convert_s(AbstractArray, b, m.bmdp.pomdp)
b̃ = compress(m.compressor, b)
s = CompressedBeliefMDPState(b̃)
return s
b = convert_s(AbstractVector{Float64}, b, m)
b̃ = get!(m.ϕ, b) do
b = convert_s(AbstractArray{Float64}, b, m) # TODO: not sure if I need a `let b = ...` here
compress(m.compressor, b) # NOTE: compress is only called if b ∉ domain(m.ϕ)
end
return b̃
end

function POMDPs.gen(m::CompressedBeliefMDP, s::CompressedBeliefMDPState, a, rng::Random.AbstractRNG)
b = decode(m, s)
function POMDPs.gen(m::CompressedBeliefMDP, b̃, a, rng::Random.AbstractRNG)
b = decode(m, b̃)
bp, r = @gen(:sp, :r)(m.bmdp, b, a, rng)
sp = encode(m, bp)
return (sp=sp, r=r)
b̃p = encode(m, bp)
return (sp=b̃p, r=r)
end

# TODO: use macro forwarding
POMDPs.actions(m::CompressedBeliefMDP, s::CompressedBeliefMDPState) = actions(m.bmdp, decode(m, s))
# TODO: read about orthogonalized code on julia documetation
POMDPs.actions(m::CompressedBeliefMDP, b̃) = actions(m.bmdp, decode(m, b̃))
POMDPs.actions(m::CompressedBeliefMDP) = actions(m.bmdp)
POMDPs.isterminal(m::CompressedBeliefMDP, s::CompressedBeliefMDPState) = isterminal(m.bmdp, decode(m, s))
POMDPs.isterminal(m::CompressedBeliefMDP, b̃) = isterminal(m.bmdp, decode(m, b̃))
POMDPs.discount(m::CompressedBeliefMDP) = discount(m.bmdp)
POMDPs.initialstate(m::CompressedBeliefMDP) = encode(m, initialstate(m.bmdp))
POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a) # TODO: figure out if can just wrap m.bmdp
POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a)

POMDPs.convert_s(t::Type, s, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp)
POMDPs.convert_s(t::Type{<:AbstractArray}, s::AbstractArray, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp) # NOTE: this second implementation is b/c to get around a requirement from POMDPLinter

# TODO: maybe don't include sparsecat
ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDFs from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions)
POMDPs.convert_s(::Type{<:AbstractArray}, s::ExplicitDistribution, m::POMDP) = [pdf(s, x) for x in states(m)]


# function POMDPs.convert_s(t::Type{V}, s, m::CompressedBeliefMDP) where V<:AbstractArray
# convert_s(t, s, m.bmdp.pomdp)
# end

# function POMDPs.convert_s(t::Type{S}, v::V, m::CompressedBeliefMDP) where {S, V<:AbstractArray}
# convert_s(t, v, m.bmdp.pomdp)
# end

# TODO: try to add issue to github for this since this is only to get around POMDPLinter





# POMDPs.convert_s(::Type{V}, s::CompressedBeliefMDPState, m::CompressedBeliefMDP) where V<:AbstractArray = convert_s(V, s.b̃, m)
function POMDPs.convert_s(::Type{V}, s::CompressedBeliefMDPState, m::CompressedBeliefMDP) where V<:AbstractArray
# TODO: come up w/ more elegant solution here
return convert_s(V, s.b̃ isa Matrix ? vec(s.b̃) : s.b̃, m)
end
POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v::AbstractArray, m::CompressedBeliefMDP) = CompressedBeliefMDPState(v)
POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v, m::CompressedBeliefMDP) = CompressedBeliefMDPState(convert_s(Vector, v, m.bmdp.pomdp))
# function POMDPs.convert_s(::Type{V}, s, m::CompressedBeliefMDP) where V<:AbstractArray
# # TODO: come up w/ more elegant solution here
# return convert_s(V, s.b̃ isa Matrix ? vec(s.b̃) : s.b̃, m)
# end
# # POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v::AbstractArray, m::CompressedBeliefMDP) = CompressedBeliefMDPState(v)
# POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v, m::CompressedBeliefMDP) = CompressedBeliefMDPState(convert_s(Vector, v, m.bmdp.pomdp))

# convenience methods
POMDPs.convert_s(::Type{<:AbstractArray}, s::DiscreteBelief, pomdp::POMDP) = s.b
POMDPs.convert_s(::Type{<:DiscreteBelief}, v, pomdp::POMDP) = DiscreteBelief(pomdp, vec(v))

ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDF from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions)
POMDPs.convert_s(::Type{<:AbstractArray}, b::ExplicitDistribution, pomdp::POMDP) = [pdf(b, s) for s in states(pomdp)]

function POMDPs.convert_s(::Type{<:SparseCat}, vec, pomdp::POMDP)
@assert length(vec) == length(states(pomdp))
values = []
probabilities = []
for (s, p) in zip(states(pomdp), vec)
if p != 0
push!(values, s)
push!(probabilities, p)
end
end
dist = SparseCat(values, probabilities)
return dist
end

POMDPs.convert_s(::Type{<:BoolDistribution}, vec, pomdp::POMDP) = BoolDistribution(vec[1])
# TODO: add conversions to Uniform + Deterministic
POMDPs.convert_s(::Type{V}, s::DiscreteBelief, p::POMDP) where V<:AbstractArray = s.b
# POMDPs.convert_s(::Type{S}, v::V, p::POMDP) where {S, V<:AbstractArray} = DiscreteBelief(p, _process(convert(Vector{Float64}, vec(v)))) # TODO: is there Julian shorthand for type conversions?
# POMDPs.convert_s(::Type{<:AbstractArray}, s::DiscreteBelief, pomdp::POMDP) = s.b
# POMDPs.convert_s(::Type{<:DiscreteBelief}, v, problem::Union{POMDP, MDP) = DiscreteBelief(pomdp, vec(v))

# ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDF from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions)
# POMDPs.convert_s(::Type{<:AbstractArray}, b::ExplicitDistribution, pomdp::POMDP) = [pdf(b, s) for s in states(pomdp)]

# function POMDPs.convert_s(::Type{<:SparseCat}, vec, pomdp::POMDP)
# @assert length(vec) == length(states(pomdp))
# values = []
# probabilities = []
# for (s, p) in zip(states(pomdp), vec)
# if p != 0
# push!(values, s)
# push!(probabilities, p)
# end
# end
# dist = SparseCat(values, probabilities)
# return dist
# end

# POMDPs.convert_s(::Type{<:BoolDistribution}, vec, pomdp::POMDP) = BoolDistribution(vec[1])
# # TODO: add conversions to Uniform + Deterministic
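
Reviewer's note: the `ϕ` field comment describes a cache that stays one-to-one because compressed beliefs are stored on a first-come, first-served basis. The sketch below illustrates that idea with two plain `Dict`s instead of Bijections.jl, so it should not be read as the package's implementation. `TwoWayCache`, `toy_compress`, `encode!`, and `decode_cached` are hypothetical names, and the rounding "compressor" is only a stand-in that deliberately causes collisions.

```julia
# Hypothetical stand-in for the Bijection-backed cache in CompressedBeliefMDP.
struct TwoWayCache{B, C}
    forward::Dict{B, C}    # belief => compressed belief
    backward::Dict{C, B}   # compressed belief => first belief that produced it
end

TwoWayCache{B, C}() where {B, C} = TwoWayCache{B, C}(Dict{B, C}(), Dict{C, B}())

# Toy, non-injective compressor (assumption): keep the first entry, rounded to one digit.
toy_compress(b::Vector{Float64}) = round.(b[1:1]; digits=1)

function encode!(cache::TwoWayCache, b)
    # Only compress beliefs that have not been seen before.
    haskey(cache.forward, b) && return cache.forward[b]
    c = toy_compress(b)
    cache.forward[b] = c
    # First come, first served: an existing entry for `c` is left untouched.
    get!(cache.backward, c, b)
    return c
end

decode_cached(cache::TwoWayCache, c) = cache.backward[c]

cache = TwoWayCache{Vector{Float64}, Vector{Float64}}()
c1 = encode!(cache, [0.52, 0.48])
c2 = encode!(cache, [0.49, 0.51])   # collides with c1 under toy_compress
first_belief = decode_cached(cache, c2)
```

In this sketch, decoding a colliding code returns the earliest belief rather than a decompressed approximation. How Bijections.jl itself reacts to a duplicate value is not shown in this diff, so that case is worth checking against its documentation.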
4 changes: 3 additions & 1 deletion src/compressors/compressor.jl
@@ -25,4 +25,6 @@ function compress end
Decompress the compressed beliefs using method associated with compressor, and returns the reconstructed beliefs.
"""
function decompress end
function decompress end

# TODO: remove decompress and make compress a functor (https://docs.julialang.org/en/v1/manual/methods/#Note-on-Optional-and-keyword-Arguments)
7 changes: 2 additions & 5 deletions src/compressors/mv_stats.jl
Expand Up @@ -9,13 +9,10 @@ function fit!(compressor::MultivariateStatsCompressor{T}, beliefs) where T<:Mult
compressor.M = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim)
end

# TODO: is there a way to solve this w/ multiple dispatch? clean up
function compress(compressor::MultivariateStatsCompressor, beliefs)
# TODO: is there better way to do this?
if ndims(beliefs) == 2
return MultivariateStats.predict(compressor.M, beliefs')'
else
return MultivariateStats.predict(compressor.M, beliefs)
end
return ndims(beliefs) == 2 ? predict(compressor.M, beliefs')' : vec(predict(compressor.M, beliefs))
end

decompress(compressor::MultivariateStatsCompressor, compressed) = MultivariateStats.reconstruct(compressor.M, compressed)
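
Reviewer's note: the removed TODO asked whether the `ndims(beliefs) == 2` branch could be expressed with multiple dispatch. One possible shape is sketched below using a toy linear compressor instead of the MultivariateStats.jl model, so `ToyLinearCompressor` and its `W` field are assumptions and not part of this commit. The point is only that a vector method and a matrix method can replace the runtime branch.

```julia
using LinearAlgebra

# Hypothetical compressor: projects d-dimensional beliefs onto the k columns of W (d×k).
struct ToyLinearCompressor
    W::Matrix{Float64}
end

# A single belief (length-d vector) maps to a length-k vector.
compress(c::ToyLinearCompressor, belief::AbstractVector) = c.W' * belief

# A batch of beliefs stored as rows (n×d) maps to an n×k matrix.
compress(c::ToyLinearCompressor, beliefs::AbstractMatrix) = beliefs * c.W

c = ToyLinearCompressor(reshape([1.0, 0.0], 2, 1))
single = compress(c, [0.3, 0.7])
batch = compress(c, [0.3 0.7; 0.6 0.4])
```

Dispatch keeps the row-versus-column convention visible in the method signatures, at the cost of one extra method per compressor type.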
10 changes: 10 additions & 0 deletions src/sampler.jl
@@ -0,0 +1,10 @@
function sample(pomdp::POMDP, policy::Policy, updater::Updater, n::Integer)
mdp = GenerativeBeliefMDP(pomdp, updater)
iter = stepthrough(mdp, policy, "s"; max_steps=n)
B = collect(Iterators.take(Iterators.cycle(iter), n))
return unique!(B)
end

function sample(pomdp::POMDP, policy::ExplorationPolicy, updater::Updater, n::Integer)
# TODO:
end
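
Reviewer's note: the new `sample` pads short rollouts with `Iterators.cycle` and then deduplicates with `unique!`. The sketch below replays that pattern on a plain vector standing in for the `stepthrough` iterator, which is an assumption. Whether the real iterator re-simulates on each pass of the cycle is not visible in this diff.

```julia
# Stand-in for a rollout that terminates after three beliefs (assumption).
episode = ["b0", "b1", "b2"]
n = 7

# Cycle through the episode until n items have been collected.
B = collect(Iterators.take(Iterators.cycle(episode), n))
padded_length = length(B)

# Deduplicate in place, as `sample` does before returning.
unique!(B)
```

If each pass yields the same beliefs, as it does for this stand-in, the result after `unique!` can contain fewer than `n` entries, so callers should not rely on receiving exactly `n` samples.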
10 changes: 0 additions & 10 deletions src/samplers/sampler.jl

This file was deleted.

39 changes: 0 additions & 39 deletions src/samplers/utils.jl

This file was deleted.
