Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the definition of AbstractTraces #14

Merged
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
coverage/
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml

.DS_Store
5 changes: 5 additions & 0 deletions src/LastDimSlices.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export LastDimSlices

using MacroTools: @forward

# See also https://github.com/JuliaLang/julia/pull/32310
5 changes: 2 additions & 3 deletions src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module Trajectories

include("samplers.jl")
include("controlers.jl")
include("traces.jl")
include("episodes.jl")
include("samplers.jl")
include("controlers.jl")
include("trajectory.jl")
include("rendering.jl")
include("common/common.jl")

end
37 changes: 2 additions & 35 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
export CircularArraySARTTraces

const CircularArraySARTTraces = Traces{
SART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -23,32 +12,10 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_action=t[:state][inds′]
) |> s.transformer
end

function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
if length(t[:state]) == length(t[:terminal]) + 1
pop!(t[:state])
pop!(t[:action])
end
push!(t[:state], x[:state])
push!(t[:action], x[:action])
end
46 changes: 4 additions & 42 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
export CircularArraySLARTTraces

const CircularArraySLARTTraces = Traces{
SLART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -26,37 +14,11 @@ function CircularArraySLARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function sample(s::BatchSampler, t::CircularArraySLARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
legal_actions_mask=t[:legal_actions_mask][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_legal_actions_mask=t[:legal_actions_mask][inds′],
next_action=t[:state][inds′]
) |> s.transformer
end

function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
if length(t[:state]) == length(t[:terminal]) + 1
pop!(t[:state])
pop!(t[:legal_actions_mask])
pop!(t[:action])
end
push!(t[:state], x[:state])
push!(t[:legal_actions_mask], x[:legal_actions_mask])
push!(t[:action], x[:action])
end
end
8 changes: 0 additions & 8 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
using CircularArrayBuffers

const SA = (:state, :action)
const SLA = (:state, :legal_actions_mask, :action)
const RT = (:reward, :terminal)
const SART = (:state, :action, :reward, :terminal)
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySLARTTraces.jl")
53 changes: 19 additions & 34 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,45 @@
export Episode, Episodes

using MLUtils: batch

"""
Episode(traces)

An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]`
to check/update whether the episode reaches a terminal or not.
"""
struct Episode{T}
struct Episode{T,E} <: AbstractVector{E}
traces::T
is_done::Ref{Bool}
is_terminated::Ref{Bool}
end

Base.getindex(e::Episode, s::Symbol) = getindex(e.traces, s)
Base.keys(e::Episode) = keys(e.traces)
Base.getindex(e::Episode, I) = getindex(e.traces, I)
Base.getindex(e::Episode) = getindex(e.is_terminated)
Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x)

Base.getindex(e::Episode) = getindex(e.is_done)
Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x)
Base.size(e::Episode) = size(e.traces)

Base.length(e::Episode) = length(e.traces)
Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false))

Episode(t::Traces) = Episode(t, Ref(false))

function Base.push!(t::Episode, x)
if t.is_done[]
throw(ArgumentError("The episode is already flagged as done!"))
else
push!(t.traces, x)
end
end

function Base.append!(t::Episode, x)
if t.is_done[]
throw(ArgumentError("The episode is already flagged as done!"))
else
append!(t.traces, x)
for f in (:push!, :pushfirst!, :append!, :prepend!)
@eval function Base.$f(t::Episode, x)
if t.is_terminated[]
throw(ArgumentError("The episode is already flagged as done!"))
else
$f(t.traces, x)
end
end
end

function Base.pop!(t::Episode)
pop!(t.traces)
t.is_done[] = false
t.is_terminated[] = false
end

Base.popfirst!(t::Episode) = popfirst!(t.traces)

function Base.empty!(t::Episode)
empty!(t.traces)
t.is_done[] = false
t.is_terminated[] = false
end

#####
Expand All @@ -58,13 +49,14 @@ end

A container for multiple [`Episode`](@ref)s. `init` is a parameterness function which return an [`Episode`](@ref).
"""
struct Episodes
struct Episodes <: AbstractVector{Episode}
init::Any
episodes::Vector{Episode}
inds::Vector{Tuple{Int,Int}}
end

Base.length(e::Episodes) = length(e.inds)
Base.size(e::Episodes) = size(e.inds)
Base.getindex(e::Episodes, I) = getindex(e.episodes, I)

function Base.push!(e::Episodes, x::Episode)
push!(e.episodes, x)
Expand Down Expand Up @@ -98,10 +90,3 @@ function Base.append!(e::Episodes, x)
push!(e.inds, (lengthe.episodes, i))
end
end
findmyway marked this conversation as resolved.
Show resolved Hide resolved

##

function sample(s::BatchSampler, e::Episodes)
inds = rand(s.rng, 1:length(t), s.batch_size)
batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer
end
135 changes: 0 additions & 135 deletions src/rendering.jl

This file was deleted.

Loading