Skip to content

Commit

Permalink
Merge pull request #72 from dharux/sartsa-fix
Browse files Browse the repository at this point in the history
Fix issues with SARTSATraces
  • Loading branch information
jeremiahpslewis authored May 10, 2024
2 parents b23dd43 + 29a6a3e commit de01fb3
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.4"
version = "0.4.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
6 changes: 3 additions & 3 deletions src/common/CircularArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ function CircularArraySARTSATraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
)
end

Expand Down
23 changes: 22 additions & 1 deletion src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ end
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
new_names = (:key, :priority, names...)
new_Ts = Tuple{Int,Float32,Ts.parameters...}
c = capacity(traces)
if traces isa CircularArraySARTSATraces
c = capacity(traces) - 1
else
c = capacity(traces)
end
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
CircularVectorBuffer{Int}(c),
SumTree(c),
Expand All @@ -34,6 +38,22 @@ function Base.push!(t::CircularPrioritizedTraces, x)
end
end

function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
initial_length = length(t.traces)
push!(t.traces, x)
if length(t.traces) == 1
push!(t.keys, 1)
push!(t.priorities, t.default_priority)
elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 )
# only add a key if the length changes after insertion of the tuple
# or if the trace is already at capacity
push!(t.keys, t.keys[end] + 1)
push!(t.priorities, t.default_priority)
else
# may be partial inserting at the first step, ignore it
end
end

function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
if k === :priority
@assert length(vs) == length(keys)
Expand All @@ -48,6 +68,7 @@ function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
end

Base.size(t::CircularPrioritizedTraces) = size(t.traces)
max_length(t::CircularPrioritizedTraces) = max_length(t.traces)

function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
if s === :priority
Expand Down
34 changes: 30 additions & 4 deletions src/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)

fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)

max_length(eb::EpisodesBuffer) = max_length(eb.traces)

function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.traces, xs)
partial = ispartial_insert(eb, xs)
Expand All @@ -146,10 +148,12 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.episodes_lengths, 0)
push!(eb.sampleable_inds, 0)
elseif !partial #typical inserting
if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later
eb.sampleable_inds[end-1] = 1
else #case when we don't, length of traces and eb will match.
eb.sampleable_inds[end] = 1 #previous step is now indexable
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
end
else
eb.sampleable_inds[end] = 1 # otherwise, previous step is now indexable
end
push!(eb.sampleable_inds, 0) #this one is no longer
ep_length = last(eb.step_numbers)
Expand All @@ -172,6 +176,28 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
if max_length(eb) == capacity(eb.traces)
popfirst!(eb)
end
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
if max_length(eb) == capacity(eb.traces)
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
xs = merge(xs.namedtuple, addition)
push!(eb.traces, xs)
pop!(eb.traces[:state].trace)
pop!(eb.traces[:reward])
pop!(eb.traces[:terminal])
else
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1
end
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
Expand Down
6 changes: 3 additions & 3 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
t = e.traces
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
w .*= e.sampleable_inds[1:end-1]
w .*= e.sampleable_inds[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
end
Expand Down Expand Up @@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
valids, ns = valid_range(s,e)
w .*= valids[1:end-1]
w .*= valids[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
merge(
(key=t.keys[inds], priority=p[inds]),
Expand Down Expand Up @@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
valids, ns = valid_range(s,e)
w .*= valids[1:end-1]
w .*= valids[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
merge(
(key=t.keys[inds], priority=p[inds]),
Expand Down
1 change: 1 addition & 0 deletions src/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T
end

Base.size(t::Traces) = (mapreduce(length, min, t.traces),)
max_length(t::Traces) = mapreduce(length, max, t.traces)

function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E}
minimum(map(idx->capacity(t[idx]), names))
Expand Down
80 changes: 70 additions & 10 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@test length(t) == 0
end

@testset "CircularArraySARTSTraces" begin
@testset "CircularArraySARTSATraces" begin
t = CircularArraySARTSATraces(;
capacity=3,
state=Float32 => (2, 3),
Expand All @@ -35,13 +35,14 @@ end

@test t isa CircularArraySARTSATraces

push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu)
push!(t, (state=ones(Float32, 2, 3),))
push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu)
@test length(t) == 0

push!(t, (reward=1.0f0, terminal=false) |> gpu)
@test length(t) == 0 # next_state and next_action is still missing
@test length(t) == 0 # next_action is still missing

push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu)
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu)
@test length(t) == 1

# this will trigger the scalar indexing of CuArray
Expand All @@ -55,17 +56,18 @@ end
)

push!(t, (reward=2.0f0, terminal=false))
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu)
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu)

@test length(t) == 2

push!(t, (reward=3.0f0, terminal=false))
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu)
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu)

@test length(t) == 3

push!(t, (reward=4.0f0, terminal=false))
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu)
push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu)
push!(t, (reward=5.0f0, terminal=false))

@test length(t) == 3

Expand Down Expand Up @@ -127,9 +129,9 @@ end
@test t isa CircularArraySLARTTraces
end

@testset "CircularPrioritizedTraces" begin
@testset "CircularPrioritizedTraces-SARTS" begin
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
CircularArraySARTSTraces(;
capacity=3
),
default_priority=1.0f0
Expand Down Expand Up @@ -160,7 +162,7 @@ end

#EpisodesBuffer
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
CircularArraySARTSTraces(;
capacity=10
),
default_priority=1.0f0
Expand All @@ -186,3 +188,61 @@ end
eb[:priority, [1, 2]] = [0, 0]
@test eb[:priority] == [zeros(2);ones(8)]
end

@testset "CircularPrioritizedTraces-SARTSA" begin
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
capacity=3
),
default_priority=1.0f0
)

push!(t, (state=0, action=0))

for i in 1:5
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
end

@test length(t) == 3

s = BatchSampler(5)

b = sample(s, t)

t[:priority, [1, 2]] = [0, 0]

# shouldn't be changed since [1,2] are old keys
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]

t[:priority, [3, 4, 5]] = [0, 1, 0]

b = sample(s, t)

@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0

#EpisodesBuffer
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
capacity=10
),
default_priority=1.0f0
)

eb = EpisodesBuffer(t)
push!(eb, (state = 1,))
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
end
push!(eb, PartialNamedTuple((action = 6,)))
push!(eb, (state = 7,))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
end
push!(eb, PartialNamedTuple((action=12,)))
s = BatchSampler(1000)
b = sample(s, eb)
cm = counter(b[:state])
@test !haskey(cm, 6)
@test !haskey(cm, 11)
@test all(in(keys(cm)), [1:5;7:10])
end
34 changes: 26 additions & 8 deletions test/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ using Test
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if length(eb) >= 1
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == i + 1
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
end
Expand All @@ -123,18 +126,24 @@ using Test
ep2_len += 1
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if eb.step_numbers[end] > 2
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == j + 1
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
end
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
@test length(eb.traces) == 9 #an action is missing at this stage
#three last steps replace oldest steps in the buffer.
for (i, s) = enumerate(12:13)
ep2_len += 1
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if eb.step_numbers[end] > 2
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == i + 1 + 4
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
end
Expand Down Expand Up @@ -299,7 +308,10 @@ using Test
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if eb.step_numbers[end] > 2
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == i + 1
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
end
Expand All @@ -321,17 +333,23 @@ using Test
ep2_len += 1
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if eb.step_numbers[end] > 2
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == j + 1
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
end
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
@test length(eb.traces) == 9 #an action is missing at this stage
for (i, s) = enumerate(12:13)
ep2_len += 1
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.sampleable_inds[end-1] == 0
if eb.step_numbers[end] > 2
@test eb.sampleable_inds[end-2] == 1
end
@test eb.step_numbers[end] == i + 1 + 4
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
end
Expand Down
14 changes: 8 additions & 6 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch
batchsize = 4
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)

push!(eb, (state = 1, action = 1))
push!(eb, (state = 1,))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5))
end
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
push!(eb, PartialNamedTuple((action=6,)))
push!(eb, (state = 7,))
for (j,i) = enumerate(7:10)
push!(eb, (state = i+1, action =i, reward = i, terminal = i==10))
end
push!(eb, PartialNamedTuple((action = 11,)))
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
Expand Down

0 comments on commit de01fb3

Please sign in to comment.