From f733630fc06eeb113f50284faa229b98ec9d2525 Mon Sep 17 00:00:00 2001 From: Logan Mondal Bhamidipaty <76822456+FlyingWorkshop@users.noreply.github.com> Date: Sat, 23 Mar 2024 15:55:35 -0700 Subject: [PATCH] major clean up! --- Project.toml | 1 + images/baby_benchmark.svg | 50 ---------- paper.md | 4 +- src/CompressedBeliefMDPs.jl | 19 ++-- src/cbmdp.jl | 182 ++++++++++++++-------------------- src/compressors/compressor.jl | 4 +- src/compressors/mv_stats.jl | 7 +- src/sampler.jl | 10 ++ src/samplers/sampler.jl | 10 -- src/samplers/utils.jl | 39 -------- src/solver.jl | 94 ++++++------------ 11 files changed, 135 insertions(+), 285 deletions(-) delete mode 100644 images/baby_benchmark.svg create mode 100644 src/sampler.jl delete mode 100644 src/samplers/sampler.jl delete mode 100644 src/samplers/utils.jl diff --git a/Project.toml b/Project.toml index 84894c8..831536f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Logan-Mondal-Bhamidipaty"] version = "1.0.0-DEV" [deps] +Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/images/baby_benchmark.svg b/images/baby_benchmark.svg deleted file mode 100644 index 6e8577b..0000000 --- a/images/baby_benchmark.svg +++ /dev/null @@ -1,50 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/paper.md b/paper.md index 9d2b58e..fe3192e 100644 --- a/paper.md +++ b/paper.md @@ -65,15 +65,13 @@ v = value(approx_policy, s) a = action(approx_policy, s) ``` -![We see that that the compressed solver recovers a similar value function to SARSOP [@SARSOP].](./images/baby_benchmark.svg){height="200pt"} - ## Function Approximators CompressedBeliefMDPs.jl is compatible with any `LocalFunctionApproximator`. It supports grid interpolations [@grid] through [GridInterpolations.jl](https://github.com/sisl/GridInterpolations.jl) and $k$-nearest neighbors [@kNN] through [NearestNeighbors.jl](https://github.com/KristofferC/NearestNeighbors.jl). For more details, see [LocalFunctionApproximation.jl](https://github.com/sisl/LocalFunctionApproximation.jl) ## Compressors -CompressedBeliefMDPs.jl provides several wrappers for commonly used compressors. Through [MultiVariateStats.jl](https://juliastats.org/MultivariateStats.jl/stable/), we include PCA [@PCA], kernel PCA [@kernelPCA], and probabilistic PCA [@PPCA]. Through a forthcoming package `ExpFamilyPCA.jl`, we also include exponential family PCA [@EPCA]. This is more general than the Poisson exponential family PCA in @Roy: ExpFamilyPCA.jl supports compression objectives induced from _any_ convex function, not just the Poisson link function. +CompressedBeliefMDPs.jl provides several wrappers for commonly used compressors. Through [MultiVariateStats.jl](https://juliastats.org/MultivariateStats.jl/stable/), we include PCA [@PCA], kernel PCA [@kernelPCA], and probabilistic PCA [@PPCA]. # Acknowledgements diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index 2f30ea8..759872a 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -4,8 +4,13 @@ using Infiltrator using POMDPs using POMDPTools +using LocalApproximationValueIteration +using LocalFunctionApproximation import Lazy: @forward +using Bijections +using NearestNeighbors +using StaticArrays using LinearAlgebra using Random @@ -27,22 +32,16 @@ include("compressors/compressor.jl") include("compressors/mv_stats.jl") export - Sampler, - sample, - # DiscreteEpsGreedySampler, - # DiscreteSoftmaxSampler, - DiscreteRandomSampler -include("samplers/sampler.jl") -include("samplers/utils.jl") + sample +include("sampler.jl") export CompressedBeliefMDP, - CompressedBeliefMDPState include("cbmdp.jl") export - CompressedSolver, - CompressedSolverPolicy + CompressedBeliefSolver, + CompressedBeliefPolicy, solve include("solver.jl") diff --git a/src/cbmdp.jl b/src/cbmdp.jl index dd8131a..4f5b994 100644 --- a/src/cbmdp.jl +++ b/src/cbmdp.jl @@ -1,134 +1,104 @@ -# TODO: add particle filters --> AbstractParticleBelief https://juliapomdp.github.io/ParticleFilters.jl/latest/beliefs/ - - -""" - struct CompressedBeliefMDP{B, A} <: MDP{B, A} - -A struct representing a Markov Decision Process (MDP) with compressed beliefs. - -# Fields -- `bmdp::GenerativeBeliefMDP`: The underlying GenerativeBeliefMDP that defines the original MDP. -- `compressor::Compressor`: The compressor used to compress beliefs in the MDP. - -## Type Parameters -- `B`: Type of the belief space. -- `A`: Type of the action space. - -## Example -```julia -pomdp = TigerPOMDP() -updater = DiscreteUpdater(pomdp) -compressor = PCACompressor(1) -generative_mdp = GenerativeBeliefMDP(pomdp, updater) -mdp = CompressedBeliefMDP(generative_mdp, compressor) -``` -""" struct CompressedBeliefMDP{B, A} <: MDP{B, A} bmdp::GenerativeBeliefMDP compressor::Compressor + ϕ::Bijection # ϕ: belief ↦ compress(compressor, belief); NOTE: While compressions aren't usually injective, we cache compressed beliefs on a first-come, first-served basis, so the *cache* is effectively bijective. end -""" - struct CompressedBeliefMDPState - -A struct representing the state of a CompressedBeliefMDP. - -# Fields -- `b̃::AbstractArray{<:Real}`: Compressed belief vector. -""" -struct CompressedBeliefMDPState - b̃::AbstractArray{<:Real} -end # TODO: merge docstring of constructor w/ docstring for the struct proper -""" - CompressedBeliefMDP(pomdp::POMDP, updater::Updater, compressor::Compressor) - -Create a CompressedBeliefMDP based on a given POMDP, updater, and compressor. - -This function initializes a `GenerativeBeliefMDP` using the provided `pomdp` and `updater`. -It then constructs a `CompressedBeliefMDP` with the specified `compressor`. - -# Arguments -- `pomdp::POMDP`: The original partially observable Markov decision process (POMDP). -- `updater::Updater`: The belief updater used in the GenerativeBeliefMDP. -- `compressor::Compressor`: The compressor used to compress beliefs in the MDP. - -# Returns -- `CompressedBeliefMDP{CompressedBeliefMDPState, actiontype(bmdp)}`: The compressed belief MDP. - -## Example -```julia -pomdp = MyPOMDP() # replace with your specific POMDP type -updater = MyUpdater() # replace with your specific updater type -compressor = MyCompressor() # replace with your specific compressor type - -mdp = CompressedBeliefMDP(pomdp, updater, compressor) -``` -""" function CompressedBeliefMDP(pomdp::POMDP, updater::Updater, compressor::Compressor) + # Hack to determine typeof(b̃) bmdp = GenerativeBeliefMDP(pomdp, updater) - return CompressedBeliefMDP{CompressedBeliefMDPState, actiontype(bmdp)}(bmdp, compressor) + b = initialstate(bmdp).val + b̃ = compress(compressor, convert_s(AbstractVector{Float64}, b, bmdp.pomdp)) + B = typeof(b) + B̃ = typeof(b̃) + ϕ = Bijection{B, B̃}() + return CompressedBeliefMDP{B̃, actiontype(bmdp)}(bmdp, compressor, ϕ) end - -function decode(m::CompressedBeliefMDP, s::CompressedBeliefMDPState) - b = decompress(m.compressor, s.b̃) - b = normalize(abs.(b), 1) # make valid probability distribution - b = convert_s(statetype(m.bmdp), b, m.bmdp.pomdp) +function decode(m::CompressedBeliefMDP, b̃) + b = m.ϕ(b̃) return b end function encode(m::CompressedBeliefMDP, b) - b = convert_s(AbstractArray, b, m.bmdp.pomdp) - b̃ = compress(m.compressor, b) - s = CompressedBeliefMDPState(b̃) - return s + b = convert_s(AbstractVector{Float64}, b, m) + b̃ = get!(m.ϕ, b) do + b = convert_s(AbstractArray{Float64}, b, m) # TODO: not sure if I need a `let b = ...` here + compress(m.compressor, b) # NOTE: compress is only called if b ∉ domain(m.ϕ) + end + return b̃ end -function POMDPs.gen(m::CompressedBeliefMDP, s::CompressedBeliefMDPState, a, rng::Random.AbstractRNG) - b = decode(m, s) +function POMDPs.gen(m::CompressedBeliefMDP, b̃, a, rng::Random.AbstractRNG) + b = decode(m, b̃) bp, r = @gen(:sp, :r)(m.bmdp, b, a, rng) - sp = encode(m, bp) - return (sp=sp, r=r) + b̃p = encode(m, bp) + return (sp=b̃p, r=r) end # TODO: use macro forwarding -POMDPs.actions(m::CompressedBeliefMDP, s::CompressedBeliefMDPState) = actions(m.bmdp, decode(m, s)) +# TODO: read about orthogonalized code on julia documetation +POMDPs.actions(m::CompressedBeliefMDP, b̃) = actions(m.bmdp, decode(m, b̃)) POMDPs.actions(m::CompressedBeliefMDP) = actions(m.bmdp) -POMDPs.isterminal(m::CompressedBeliefMDP, s::CompressedBeliefMDPState) = isterminal(m.bmdp, decode(m, s)) +POMDPs.isterminal(m::CompressedBeliefMDP, b̃) = isterminal(m.bmdp, decode(m, b̃)) POMDPs.discount(m::CompressedBeliefMDP) = discount(m.bmdp) POMDPs.initialstate(m::CompressedBeliefMDP) = encode(m, initialstate(m.bmdp)) -POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a) # TODO: figure out if can just wrap m.bmdp +POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a) + +POMDPs.convert_s(t::Type, s, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp) +POMDPs.convert_s(t::Type{<:AbstractArray}, s::AbstractArray, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp) # NOTE: this second implementation is b/c to get around a requirement from POMDPLinter + +# TODO: maybe don't include sparsecat +ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDFs from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions) +POMDPs.convert_s(::Type{<:AbstractArray}, s::ExplicitDistribution, m::POMDP) = [pdf(s, x) for x in states(m)] + + +# function POMDPs.convert_s(t::Type{V}, s, m::CompressedBeliefMDP) where V<:AbstractArray +# convert_s(t, s, m.bmdp.pomdp) +# end + +# function POMDPs.convert_s(t::Type{S}, v::V, m::CompressedBeliefMDP) where {S, V<:AbstractArray} +# convert_s(t, v, m.bmdp.pomdp) +# end + +# TODO: try to add issue to github for this since this is only to get around POMDPLinter + + + + # POMDPs.convert_s(::Type{V}, s::CompressedBeliefMDPState, m::CompressedBeliefMDP) where V<:AbstractArray = convert_s(V, s.b̃, m) -function POMDPs.convert_s(::Type{V}, s::CompressedBeliefMDPState, m::CompressedBeliefMDP) where V<:AbstractArray - # TODO: come up w/ more elegant solution here - return convert_s(V, s.b̃ isa Matrix ? vec(s.b̃) : s.b̃, m) -end -POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v::AbstractArray, m::CompressedBeliefMDP) = CompressedBeliefMDPState(v) -POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v, m::CompressedBeliefMDP) = CompressedBeliefMDPState(convert_s(Vector, v, m.bmdp.pomdp)) +# function POMDPs.convert_s(::Type{V}, s, m::CompressedBeliefMDP) where V<:AbstractArray +# # TODO: come up w/ more elegant solution here +# return convert_s(V, s.b̃ isa Matrix ? vec(s.b̃) : s.b̃, m) +# end +# # POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v::AbstractArray, m::CompressedBeliefMDP) = CompressedBeliefMDPState(v) +# POMDPs.convert_s(::Type{CompressedBeliefMDPState}, v, m::CompressedBeliefMDP) = CompressedBeliefMDPState(convert_s(Vector, v, m.bmdp.pomdp)) # convenience methods -POMDPs.convert_s(::Type{<:AbstractArray}, s::DiscreteBelief, pomdp::POMDP) = s.b -POMDPs.convert_s(::Type{<:DiscreteBelief}, v, pomdp::POMDP) = DiscreteBelief(pomdp, vec(v)) - -ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDF from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions) -POMDPs.convert_s(::Type{<:AbstractArray}, b::ExplicitDistribution, pomdp::POMDP) = [pdf(b, s) for s in states(pomdp)] - -function POMDPs.convert_s(::Type{<:SparseCat}, vec, pomdp::POMDP) - @assert length(vec) == length(states(pomdp)) - values = [] - probabilities = [] - for (s, p) in zip(states(pomdp), vec) - if p != 0 - push!(values, s) - push!(probabilities, p) - end - end - dist = SparseCat(values, probabilities) - return dist -end - -POMDPs.convert_s(::Type{<:BoolDistribution}, vec, pomdp::POMDP) = BoolDistribution(vec[1]) -# TODO: add conversions to Uniform + Deterministic \ No newline at end of file +POMDPs.convert_s(::Type{V}, s::DiscreteBelief, p::POMDP) where V<:AbstractArray = s.b +# POMDPs.convert_s(::Type{S}, v::V, p::POMDP) where {S, V<:AbstractArray} = DiscreteBelief(p, _process(convert(Vector{Float64}, vec(v)))) # TODO: is there Julian shorthand for type conversions? +# POMDPs.convert_s(::Type{<:AbstractArray}, s::DiscreteBelief, pomdp::POMDP) = s.b +# POMDPs.convert_s(::Type{<:DiscreteBelief}, v, problem::Union{POMDP, MDP) = DiscreteBelief(pomdp, vec(v)) + +# ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDF from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions) +# POMDPs.convert_s(::Type{<:AbstractArray}, b::ExplicitDistribution, pomdp::POMDP) = [pdf(b, s) for s in states(pomdp)] + +# function POMDPs.convert_s(::Type{<:SparseCat}, vec, pomdp::POMDP) +# @assert length(vec) == length(states(pomdp)) +# values = [] +# probabilities = [] +# for (s, p) in zip(states(pomdp), vec) +# if p != 0 +# push!(values, s) +# push!(probabilities, p) +# end +# end +# dist = SparseCat(values, probabilities) +# return dist +# end + +# POMDPs.convert_s(::Type{<:BoolDistribution}, vec, pomdp::POMDP) = BoolDistribution(vec[1]) +# # TODO: add conversions to Uniform + Deterministic \ No newline at end of file diff --git a/src/compressors/compressor.jl b/src/compressors/compressor.jl index e5e6d89..95a212c 100644 --- a/src/compressors/compressor.jl +++ b/src/compressors/compressor.jl @@ -25,4 +25,6 @@ function compress end Decompress the compressed beliefs using method associated with compressor, and returns the reconstructed beliefs. """ -function decompress end \ No newline at end of file +function decompress end + +# TODO: remove decompress and make compress a functor (https://docs.julialang.org/en/v1/manual/methods/#Note-on-Optional-and-keyword-Arguments) \ No newline at end of file diff --git a/src/compressors/mv_stats.jl b/src/compressors/mv_stats.jl index 038db88..a4f019e 100644 --- a/src/compressors/mv_stats.jl +++ b/src/compressors/mv_stats.jl @@ -9,13 +9,10 @@ function fit!(compressor::MultivariateStatsCompressor{T}, beliefs) where T<:Mult compressor.M = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim) end +# TODO: is there a way to solve this w/ multiple dispatch? clean up function compress(compressor::MultivariateStatsCompressor, beliefs) # TODO: is there better way to do this? - if ndims(beliefs) == 2 - return MultivariateStats.predict(compressor.M, beliefs')' - else - return MultivariateStats.predict(compressor.M, beliefs) - end + return ndims(beliefs) == 2 ? predict(compressor.M, beliefs')' : vec(predict(compressor.M, beliefs)) end decompress(compressor::MultivariateStatsCompressor, compressed) = MultivariateStats.reconstruct(compressor.M, compressed) diff --git a/src/sampler.jl b/src/sampler.jl new file mode 100644 index 0000000..b2bc9fa --- /dev/null +++ b/src/sampler.jl @@ -0,0 +1,10 @@ +function sample(pomdp::POMDP, policy::Policy, updater::Updater, n::Integer) + mdp = GenerativeBeliefMDP(pomdp, updater) + iter = stepthrough(mdp, policy, "s"; max_steps=n) + B = collect(Iterators.take(Iterators.cycle(iter), n)) + return unique!(B) +end + +function sample(pomdp::POMDP, policy::ExplorationPolicy, updater::Updater, n::Integer) + # TODO: +end \ No newline at end of file diff --git a/src/samplers/sampler.jl b/src/samplers/sampler.jl deleted file mode 100644 index 1911caf..0000000 --- a/src/samplers/sampler.jl +++ /dev/null @@ -1,10 +0,0 @@ -abstract type Sampler end - -""" - sample(sampler::Sampler, pomdp::POMDP; n_samples::Integer=100) - -Return a matrix of beliefs sampled from pomdp. -""" -function sample end - -# TODO: rewrite this based on Mykel's email \ No newline at end of file diff --git a/src/samplers/utils.jl b/src/samplers/utils.jl deleted file mode 100644 index eb35d2c..0000000 --- a/src/samplers/utils.jl +++ /dev/null @@ -1,39 +0,0 @@ -# TODO: replace w/ take (get GPT to simplify this loop) -function sample(pomdp::POMDP, n_samples::Integer; policy::Policy=RandomPolicy(pomdp), updater::Updater=DiscreteUpdater(pomdp)) - B = [] - while true - for dist in stepthrough(pomdp, policy, updater, "b", max_steps=n_samples) - push!(B, dist.b) - if length(B) == n_samples - return hcat(B...)' # convert to n_samples x n_states matrix - end - end - end -end - - -struct BaseSampler <: Sampler - policy::Policy - updater::Updater -end - -sample(sampler::BaseSampler, pomdp::POMDP; n_samples::Integer=100) = sample(pomdp, n_samples; sampler.policy, sampler.updater) -DiscreteSampler(pomdp::POMDP, policy::Policy) = BaseSampler(policy, DiscreteUpdater(pomdp)) - -function DiscreteEpsGreedySampler(pomdp::POMDP, eps; rng::AbstractRNG=Random.GLOBAL_RNG) - policy = EpsGreedyPolicy(pomdp, eps; rng=rng) - return DiscreteSampler(pomdp, policy) -end - -# TODO: figure out a better default schedule -# TODO: replace this w/ custom eps greedy policy -function DiscreteSoftmaxSampler(pomdp::POMDP, temperature; rng::AbstractRNG=Random.GLOBAL_RNG) - policy = SoftmaxPolicy(pomdp, temperature; rng=rng) - return DiscreteSampler(pomdp, policy) -end - -function DiscreteRandomSampler(pomdp::POMDP; rng::AbstractRNG=Random.GLOBAL_RNG) - updater = DiscreteUpdater(pomdp) - policy = RandomPolicy(pomdp; rng=rng, updater=updater) - return BaseSampler(policy, updater) -end \ No newline at end of file diff --git a/src/solver.jl b/src/solver.jl index 47d9704..e778cd0 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -1,69 +1,41 @@ -using NearestNeighbors -using StaticArrays -using LocalFunctionApproximation -using LocalApproximationValueIteration: LocalApproximationValueIterationSolver, LocalApproximationValueIterationPolicy - -# TODO: add references to the appropriate local approx libraries in the docstring -struct CompressedSolver <: POMDPs.Solver - sampler::Sampler - compressor::Compressor - approximator::LocalFunctionApproximator +struct CompressedBeliefSolver <: Solver + explorer::Union{Policy, ExplorationPolicy} updater::Updater - base_solver::LocalApproximationValueIterationSolver -end - -struct CompressedSolverPolicy <: Policy - m::CompressedBeliefMDP - base_policy::LocalApproximationValueIterationPolicy -end - -# TODO: use macro forwarding -function POMDPs.action(p::CompressedSolverPolicy, b) - s = encode(p.m, b) - return action(p.base_policy, s) -end - -function POMDPs.value(p::CompressedSolverPolicy, b) - s = encode(p.m, b) - return value(p.base_policy, s) + compressor::Compressor + base_solver::Solver + n::Integer end -POMDPs.updater(policy::CompressedSolverPolicy) = policy.m.bmdp.updater - -function CompressedSolver( - pomdp::POMDP, - sampler::BaseSampler, - compressor::Compressor; - n_samples::Integer=100, - k::Integer=1, - verbose=false, - n_generative_samples=500, - max_iterations=1000 +function CompressedBeliefSolver( + explorer::Union{Policy, ExplorationPolicy}, + updater::Updater, + compressor::Compressor, + base_solver::Solver; + n=100 ) - # sample and compress beliefs - B = sample(sampler, pomdp; n_samples=n_samples) - fit!(compressor, B) - B̃ = compress(compressor, B) - - # make function approximator - data = [SVector(row...) for row in eachrow(B̃)] - tree = KDTree(data) - approximator = LocalNNFunctionApproximator(tree, data, k) - - base_solver = LocalApproximationValueIterationSolver( - approximator; - verbose=verbose, - max_iterations=max_iterations, - is_mdp_generative=true, - n_generative_samples=n_generative_samples - ) - - return CompressedSolver(sampler, compressor, approximator, sampler.updater, base_solver) + return CompressedBeliefSolver(explorer, updater, compressor, base_solver, n) end -function POMDPs.solve(solver::CompressedSolver, pomdp::POMDP) - cbmdp = CompressedBeliefMDP(pomdp, solver.updater, solver.compressor) - approx_policy = solve(solver.base_solver, cbmdp) - return CompressedSolverPolicy(cbmdp, approx_policy) +# TODO: make compressed solver that infers everything +# TODO: make compressed solver that uses local FA solver + +struct CompressedBeliefPolicy <: Policy + m::CompressedBeliefMDP + base_policy::Policy end +POMDPs.action(p::CompressedBeliefPolicy, b) = action(p.base_policy, encode(m, b)) +POMDPs.value(p::CompressedBeliefPolicy, b) = value(p.base_policy, encode(m, b)) +POMDPs.updater(p::CompressedBeliefPolicy) = p.m.bmdp.updater + +function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP) + B = sample(pomdp, solver.explorer, solver.updater, solver.n) + B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' + fit!(solver.compressor, B_numerical) + B̃ = compress(solver.compressor, B_numerical) + m = CompressedBeliefMDP(pomdp, solver.updater, solver.compressor) + ϕ = Dict(unique(t->t[2], zip(B, eachrow(B̃)))) + merge!(m.ϕ, ϕ) # update compression cache + base_policy = solve(solver.base_solver, m) + return CompressedBeliefPolicy(m, base_policy) +end \ No newline at end of file