"""
    gen(...)

Sample from the generative model of a POMDP or MDP.

In most cases, solver and simulator writers should use the `@gen` macro. Problem writers may wish to implement one or more new methods of the function for their problem.

There are three versions of the function:

- The most convenient version to implement is `gen(m::Union{MDP,POMDP}, s, a, rng::AbstractRNG)`, which returns a `NamedTuple`.
- Defining behavior for and sampling from individual nodes of the dynamic decision network can be accomplished using the version with a `DDNNode` argument.
- A version with a `DDNOut` argument is provided automatically by POMDPs.jl to sample multiple nodes at once.

See below for detailed documentation of each version.

---

    gen(m::Union{MDP,POMDP}, s, a, rng::AbstractRNG)

Convenience function for implementing the entire MDP/POMDP generative model in one function by returning a `NamedTuple`.

The `NamedTuple` version of `gen` is the most convenient for problem writers to implement. However, it should *never* be used directly by solvers or simulators. Instead, solvers and simulators should use the version with a `DDNOut` first argument.

# Arguments

- `m`: an `MDP` or `POMDP` model
- `s`: the current state
- `a`: the action
- `rng`: a random number generator (typically a `MersenneTwister`)

# Return

The function should return a [`NamedTuple`](https://docs.julialang.org/en/v1/base/base/#Core.NamedTuple). Typically, this `NamedTuple` will be `(sp=<next state>, r=<reward>)` for an `MDP` or `(sp=<next state>, o=<observation>, r=<reward>)` for a `POMDP`.
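
As an illustration, the `NamedTuple` pattern can be sketched for a hypothetical one-dimensional random-walk problem. `RandomWalkMDP` and `toy_gen` are illustrative names, not part of POMDPs.jl. The struct deliberately does not subtype `MDP` so that the snippet runs without the package. In a real problem, the same body would go in a `gen(m::MyMDP, s, a, rng)` method.

```julia
using Random

# Hypothetical 1-D random-walk problem (illustration only, not a POMDPs.jl type).
struct RandomWalkMDP
    p_slip::Float64   # probability that the chosen step is reversed
end

# NamedTuple-style generative model: sample next state and reward together.
function toy_gen(m::RandomWalkMDP, s::Int, a::Int, rng::AbstractRNG)
    step = rand(rng) < m.p_slip ? -a : a   # all randomness comes from `rng`
    sp = s + step
    r = sp == 0 ? 1.0 : 0.0                # reward 1 for reaching the origin
    return (sp=sp, r=r)
end

out = toy_gen(RandomWalkMDP(0.0), 1, -1, MersenneTwister(1))
out.sp, out.r   # (0, 1.0): with p_slip = 0 the step is never reversed
```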

---

    gen(v::DDNNode{name}, m::Union{MDP,POMDP}, depargs..., rng::AbstractRNG)

Sample a value from a node in the dynamic decision network.

These functions are used within `gen(::DDNOut, ...)` to sample values for all outputs and their dependencies. A problem writer may implement them directly to provide a generative model for a particular node of the dynamic decision network, and solvers may call them to sample a value for a particular node.

# Arguments

- `v::DDNNode{name}`: which DDN node the function should sample from.
- `m`: an `MDP` or `POMDP` model
- `depargs`: values of the nodes that node `name` depends on. Dependencies are determined by `deps(DDNStructure(m), name)`.
- `rng`: a random number generator (typically a `MersenneTwister`)

# Return

A sampled value from the specified node.

# Examples

Let `m` be a `POMDP`, `s` and `sp` be states of `m`, `a` be an action of `m`, and `rng` be an `AbstractRNG`.

- `gen(DDNNode(:sp), m, s, a, rng)` returns the next state.
- `gen(DDNNode(:o), m, s, a, sp, rng)` returns the observation given the previous state, action, and new state.
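
The per-node pattern can be sketched without POMDPs.jl. `Node{name}` is a hypothetical stand-in for `DDNNode{name}`, and `NoisyCorridor` is a hypothetical problem. Each node gets its own method by dispatching on the name stored in the type, and its dependency values are passed positionally before `rng`, mirroring the examples above. The hypothetical `sample_all` shows one way the nodes could be chained in dependency order. It is not the code POMDPs.jl generates.

```julia
using Random

# Hypothetical stand-in for DDNNode{name}: the node name is a type parameter,
# so each node can have its own method selected by dispatch.
struct Node{name} end
Node(name::Symbol) = Node{name}()

# Hypothetical toy POMDP on the integers.
struct NoisyCorridor
    p_correct_obs::Float64   # probability that the observation equals sp
end

# :sp depends on (s, a)
node_gen(::Node{:sp}, m::NoisyCorridor, s, a, rng::AbstractRNG) = s + a

# :o depends on (s, a, sp); the reading is off by one with prob. 1 - p_correct_obs
function node_gen(::Node{:o}, m::NoisyCorridor, s, a, sp, rng::AbstractRNG)
    return rand(rng) < m.p_correct_obs ? sp : sp + rand(rng, [-1, 1])
end

# :r depends on (s, a, sp)
node_gen(::Node{:r}, m::NoisyCorridor, s, a, sp, rng::AbstractRNG) = -abs(sp)

# Sample every node in dependency order, feeding parents' values to children.
function sample_all(m::NoisyCorridor, s, a, rng::AbstractRNG)
    sp = node_gen(Node(:sp), m, s, a, rng)
    o = node_gen(Node(:o), m, s, a, sp, rng)
    r = node_gen(Node(:r), m, s, a, sp, rng)
    return (sp, o, r)
end

sample_all(NoisyCorridor(1.0), 2, 1, MersenneTwister(0))   # (3, 3, -3)
```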

---

    gen(t::DDNOut{X}, m::Union{MDP,POMDP}, s, a, rng::AbstractRNG) where X

Sample values from several nodes in the dynamic decision network. `X` is a symbol or tuple of symbols indicating which nodes to output.

An implementation of this method is automatically provided by POMDPs.jl. Solvers and simulators should use this version. Problem writers may implement it directly in special cases (see the POMDPs.jl documentation for more information).

# Arguments

- `t::DDNOut`: which DDN nodes the function should sample from.
- `m`: an `MDP` or `POMDP` model
- `s`: the current state
- `a`: the action
- `rng`: a random number generator (typically a `MersenneTwister`)

# Return

If the `DDNOut` parameter `X` is a symbol, returns a value sampled from the corresponding node. If `X` is a tuple of symbols, returns a `Tuple` of values sampled from the specified nodes.

# Examples

Let `m` be an `MDP` or `POMDP`, `s` be a state of `m`, `a` be an action of `m`, and `rng` be an `AbstractRNG`.

- `gen(DDNOut(:sp, :r), m, s, a, rng)` returns a `Tuple` containing the next state and reward.
- `gen(DDNOut(:sp, :o, :r), m, s, a, rng)` returns a `Tuple` containing the next state, observation, and reward.
- `gen(DDNOut(:sp), m, s, a, rng)` returns the next state.
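
The return convention (a single value for a symbol, a `Tuple` for several symbols) can be sketched as a selection from a full sample. This assumes a hypothetical `Out{X}` type in place of `DDNOut{X}` and a sample already stored in a `NamedTuple`. The real method provided by POMDPs.jl samples the required nodes itself, so this sketch illustrates only how `X` shapes the result.

```julia
# Hypothetical stand-in for DDNOut{X}: X is a Symbol or a Tuple of Symbols.
struct Out{X} end
Out(syms::Symbol...) = Out{length(syms) == 1 ? syms[1] : syms}()

# Return a single value for a Symbol parameter and a Tuple for a Tuple parameter.
select_outputs(::Out{X}, nt::NamedTuple) where {X} =
    X isa Symbol ? nt[X] : map(k -> nt[k], X)

full = (sp=4, o=5, r=-1.0)   # hypothetical full sample of all nodes

select_outputs(Out(:sp, :r), full)   # (4, -1.0)
select_outputs(Out(:sp), full)       # 4
```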
"""
function gen end
"""
    initialstate(m::Union{POMDP,MDP}, rng::AbstractRNG)

Return a sampled initial state for the problem `m`.

Usually the initial state is sampled from an initial state distribution. The random number generator `rng` should be used to draw this sample (e.g. use `rand(rng)` instead of `rand()`). If this method is not implemented for `m` but `initialstate_distribution(m)` is, the default returns `rand(rng, initialstate_distribution(m))`.
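
The pattern of sampling an initial state from a distribution with `rng` can be sketched for a hypothetical problem whose initial state is a uniform choice over a few start cells. `GridProblem`, `initial_distribution`, and `sample_initial` are illustrative names, not POMDPs.jl functions. Seeding the same RNG twice reproduces the same initial state. This reproducibility is why `rng` should be used instead of the global RNG.

```julia
using Random

# Hypothetical problem whose initial state is uniform over a few start cells.
struct GridProblem
    start_cells::Vector{Tuple{Int,Int}}
end

# Hypothetical "distribution": a vector, which `rand(rng, v)` samples uniformly.
initial_distribution(m::GridProblem) = m.start_cells

# Draw one initial state from the distribution using the supplied `rng`.
sample_initial(m::GridProblem, rng::AbstractRNG) = rand(rng, initial_distribution(m))

m = GridProblem([(1, 1), (1, 2), (2, 1)])
s1 = sample_initial(m, MersenneTwister(42))
s2 = sample_initial(m, MersenneTwister(42))
s1 == s2   # true: same seed, same initial state
```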
"""
function initialstate end

function implemented(f::typeof(initialstate), TT::Type)
    if !hasmethod(f, TT)
        return false
    end
    m = which(f, TT)
    if m.module == POMDPs && !implemented(initialstate_distribution, Tuple{TT.parameters[1]})
        return false
    else
        return true
    end
end

@generated function initialstate(p::Union{POMDP,MDP}, rng)
    impl = quote
        d = initialstate_distribution(p)
        return rand(rng, d)
    end

    # it is technically illegal to call this within the generated function
    if implemented(initialstate_distribution, Tuple{p})
        return impl
    else
        return quote
            try
                $impl # trick to get the compiler to insert the right backedges
            catch
                throw(MethodError(initialstate, (p, rng)))
            end
        end
    end
end

"""
    initialobs(m::POMDP, s, rng::AbstractRNG)

Return a sampled initial observation for the problem `m` and state `s`.

This function is only used in cases where the policy expects an initial observation rather than an initial belief, e.g. in a reinforcement learning setting. It is not used in a standard POMDP simulation.

By default, it samples from the observation distribution for `s`, i.e. it returns `rand(rng, observation(m, s))`. The random number generator `rng` should be used to draw this sample (e.g. use `rand(rng)` instead of `rand()`).
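
The default can be sketched for a hypothetical `BeaconPOMDP` whose observation distribution for state `s` is a uniform choice among readings `s + d` for a fixed set of offsets `d`. `observation_support` and `initial_obs` are illustrative names standing in for `observation` and `initialobs`. They are not the POMDPs.jl implementations.

```julia
using Random

# Hypothetical toy POMDP with an additive, uniformly distributed sensor offset.
struct BeaconPOMDP
    offsets::Vector{Int}
end

# Stand-in for observation(m, s): the possible readings, sampled uniformly.
observation_support(m::BeaconPOMDP, s::Int) = [s + d for d in m.offsets]

# Default-style initial observation: one draw from the observation distribution.
initial_obs(m::BeaconPOMDP, s::Int, rng::AbstractRNG) =
    rand(rng, observation_support(m, s))

m = BeaconPOMDP([-1, 0, 1])
o = initial_obs(m, 10, MersenneTwister(7))   # one of 9, 10, 11
```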
"""
function initialobs end

function implemented(f::typeof(initialobs), TT::Type)
    if !hasmethod(f, TT)
        return false
    end
    m = which(f, TT)
    if m.module == POMDPs && !implemented(observation, Tuple{TT.parameters[1:2]...})
        return false
    else
        return true
    end
end

@generated function initialobs(m::POMDP, s, rng)
    impl = quote
        d = observation(m, s)
        return rand(rng, d)
    end

    # it is technically illegal to call this within the generated function
    if implemented(observation, Tuple{m, s})
        return impl
    else
        return quote
            try
                $impl # trick to get the compiler to insert the right backedges
            catch
                throw(MethodError(initialobs, (m, s, rng)))
            end
        end
    end
end

"""
    @gen(X)(m, s, a)
    @gen(X)(m, s, a, rng::AbstractRNG)

Call the generative model for a (PO)MDP `m`, sampling values from several nodes in the dynamic decision network. `X` is one or more symbols indicating which nodes to output.

Solvers and simulators should usually call this rather than the `gen` function. Problem writers should implement methods of the `gen` function.

# Arguments

- `m`: an `MDP` or `POMDP` model
- `s`: the current state
- `a`: the action
- `rng`: a random number generator (typically a `MersenneTwister`); if omitted, `Random.GLOBAL_RNG` is used

# Return

If `X` is a single symbol, returns a value sampled from the corresponding node. If `X` is several symbols, returns a `Tuple` of values sampled from the specified nodes.

# Examples

Let `m` be an `MDP` or `POMDP`, `s` be a state of `m`, `a` be an action of `m`, and `rng` be an `AbstractRNG`.

- `@gen(:sp, :r)(m, s, a, rng)` returns a `Tuple` containing the next state and reward.
- `@gen(:sp, :o, :r)(m, s, a, rng)` returns a `Tuple` containing the next state, observation, and reward.
- `@gen(:sp)(m, s, a, rng)` returns the next state.
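
`@gen(X)` expands to a function of `(m, s, a, rng)` whose `rng` argument defaults to `Random.GLOBAL_RNG`. The same idea can be sketched without a macro. This sketch assumes a hypothetical `toy_model` that returns a `NamedTuple` and a hypothetical helper `outputs` that plays the role of `@gen`. Neither is part of POMDPs.jl, and the model argument is dropped for brevity.

```julia
using Random

# Hypothetical full generative model (model argument omitted for brevity).
toy_model(s, a, rng::AbstractRNG) = (sp=s + a, r=-abs(s + a))

# Function-based analogue of @gen(syms...): returns a callable that samples
# the model and keeps only the requested outputs; rng defaults to the global RNG.
function outputs(syms::Symbol...)
    function sampler(s, a, rng::AbstractRNG=Random.GLOBAL_RNG)
        nt = toy_model(s, a, rng)
        return length(syms) == 1 ? nt[syms[1]] : map(k -> nt[k], syms)
    end
    return sampler
end

outputs(:sp, :r)(2, 1, MersenneTwister(0))   # (3, -3)
outputs(:sp)(2, 1)                           # 3, using the global RNG
```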
"""
macro gen(symbols...)
    quote
        # this should be an anonymous function, but there is a bug
        f(m, s, a, rng=Random.GLOBAL_RNG) = gen(DDNOut($(symbols...)), m, s, a, rng)
    end
end