From 518f00b84be4400b80479178e9ce1fb4b81ac9af Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 20 May 2022 18:13:59 +0800 Subject: [PATCH 01/17] update latest upstream --- src/traces.jl | 236 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 188 insertions(+), 48 deletions(-) diff --git a/src/traces.jl b/src/traces.jl index 02a96ab..0985b1c 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,52 +1,91 @@ export Trace, Traces, sample +import MacroTools: @forward + """ - Trace(data) + AbstractTrace{T} <: AbstractVector{T} + +An `AbstractTrace` is a subtype of `AbstractVector`. The following methods should be implemented: + +- `Base.length` +- `Base.firstindex` +- `Base.lastindex` +- `Base.getindex` +- `Base.view` +- `Base.push!` +- `Base.append!` +- `Base.empty!` +- `Base.pop!` +- `Base.popfirst!` +""" +abstract type AbstractTrace{T} <: AbstractVector{T} end -A wrapper of arbitrary container. Generally we assume the `data` is an -`AbstractVector` like object. When an `AbstractArray` is given, we view it as a -vector of sub-arrays along the last dimension. """ -struct Trace{T} - x::T -end + AbstractTraces{names} + +An `AbstractTraces` is a group of different [`AbstractTrace`](@ref). Following methods must be implemented: + +- `Base.getindex`, get the inner `AbstractTrace` given a trace name. +- `Base.keys` +- `Base.haskey` +- `Base.push!` +- `Base.append!` +- `Base.pop!` +- `Base.popfirst!` +- `Base.empty!` +""" +abstract type AbstractTraces{names} end -Base.length(t::Trace) = length(t.x) -Base.length(t::Trace{<:AbstractArray}) = size(t.x, ndims(t.x)) +Base.keys(t::AbstractTraces{names}) where {names} = names +Base.haskey(t::AbstractTraces{names}) where {names} = haskey(names) -Base.lastindex(t::Trace) = length(t) -Base.firstindex(t::Trace) = 1 +Base.push!(t::AbstractTraces; kw...) 
= push!(t, values(kw)) -Base.convert(::Type{Trace}, x) = Trace(x) +function Base.push!(t::AbstractTraces, x::NamedTuple) + for k in keys(x) + push!(t[k], x[k]) + end +end -Base.getindex(t::Trace{<:AbstractVector}, I...) = getindex(t.x, I...) -Base.view(t::Trace{<:AbstractVector}, I...) = view(t.x, I...) +Base.append!(t::AbstractTraces; kw...) = append!(t, values(kw)) -Base.getindex(t::Trace{<:AbstractArray}, I...) = getindex(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) -Base.view(t::Trace{<:AbstractArray}, I...) = view(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) +function Base.append!(t::AbstractTraces, x::NamedTuple) + for k in keys(x) + append!(t[k], x[k]) + end +end -Base.push!(t::Trace, x) = push!(t.x, x) -Base.append!(t::Trace, x) = append!(t.x, x) -Base.pop!(t::Trace) = pop!(t.x) -Base.popfirst!(t::Trace) = popfirst!(t.x) -Base.empty!(t::Trace) = empty!(t.x) +##### -## +""" + Trace(data) -function sample(s::BatchSampler, t::Trace) - inds = rand(s.rng, 1:length(t), s.batch_size) - t[inds] |> s.transformer +The most common [`AbstractTrace`](@ref). A wrapper of arbitrary container. +Generally we assume the `data` is an `AbstractVector` like object. When an +`AbstractArray` is given, we view it as a vector of sub-arrays along the last +dimension. +""" +struct Trace{T} <: AbstractTrace + x::T end +@forward Trace.x Base.length, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.push!, Base.append!, Base.pop!, Base.popfirst!, Base.empty! + +Base.convert(::Type{Trace}, x) = Trace(x) + +Base.length(t::Trace{<:AbstractArray}) = size(t.x, ndims(t.x)) +Base.getindex(t::Trace{<:AbstractArray}, I...) = getindex(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) +Base.view(t::Trace{<:AbstractArray}, I...) = view(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) + ##### """ Traces(;kw...) -A container of several named-[`Trace`](@ref)s. Each element in the `kw` will be converted into a `Trace`. +A container of several named-[`AbstractTrace`](@ref)s. 
Each element in the `kw` will be converted into a `Trace`. """ -struct Traces{names,T} +struct Traces{names,T} <: AbstractTraces traces::NamedTuple{names,T} function Traces(; kw...) traces = map(x -> convert(Trace, x), values(kw)) @@ -54,35 +93,136 @@ struct Traces{names,T} end end -Base.keys(t::Traces) = keys(t.traces) -Base.haskey(t::Traces, s::Symbol) = haskey(t.traces, s) -Base.getindex(t::Traces, x) = getindex(t.traces, x) -Base.length(t::Traces) = mapreduce(length, min, t.traces) +@forward Traces.traces Base.getindex -Base.push!(t::Traces; kw...) = push!(t, values(kw)) +Base.pop!(t::Traces) = map(pop!, t.traces) +Base.popfirst!(t::Traces) = map(popfirst!, t.traces) +Base.empty!(t::Traces) = map(empty!, t.traces) -function Base.push!(t::Traces, x::NamedTuple) - for k in keys(x) - push!(t[k], x[k]) +##### + +""" + MultiplexTraces{names}(trace) + +A special [`AbstractTraces`](@ref) which has exactly two traces of the same +length. And those two traces share the header and tail part. + +For example, if a `trace` contains elements between 0 and 9, then the first +`trace_A` is a view of elements from 0 to 8 and the second one is a view from 1 +to 9. + +``` + ┌─────trace_A───┐ +trace 0 1 2 3 4 5 6 7 8 9 + └────trace_B────┘ +``` + +This is quite common in RL to represent `states` and `next_states`. +""" +struct MultiplexTraces{names,T} <: AbstractTraces{names} + trace::Trace{T} +end + +function MultiplexTraces{names}(trace) where {names} + if length(names) != 2 + throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names")) end + t = convert(Trace, trace) + MultiplexTraces{names,typeof(t)}(t) end -Base.append!(t::Traces; kw...) = append!(t, values(kw)) +@forward MultiplexTraces.trace Base.pop!, Base.popfirst!, Base.empty! 
-function Base.append!(t::Traces, x::NamedTuple) - for k in keys(x) - append!(t[k], x[k]) +Base.getindex(t::MultiplexTraces, i::Int) = getindex(t, keys(t)[i]) + +function Base.getindex(t::MultiplexTraces, k::Symbol) + a, b = keys(t) + if k == a + @view t.trace[1:end-1] + elseif k == b + @view t.trace[2:end] + else + throw(ArgumentError("unknown trace name: $k")) end end -Base.pop!(t::Traces) = map(pop!, t.traces) -Base.popfirst!(t::Traces) = map(popfirst!, t.traces) -Base.empty!(t::Traces) = map(empty!, t.traces) +function Base.push!(t::MultiplexTraces, x::NamedTuple{ks,Tuple{Ts}}) where {ks,Ts} + k, v = first(ks), first(x) + if k in keys(t) + push!(t.trace, v) + else + throw(ArgumentError("unknown trace name: $k")) + end +end + +function Base.append!(t::MultiplexTraces, x::NamedTuple{ks,Tuple{Ts}}) where {ks,Ts} + k, v = first(ks), first(x) + if k in keys(t) + append!(t.trace, v) + else + throw(ArgumentError("unknown trace name: $k")) + end +end + +##### + +struct MergedTraces{names,T,N} <: AbstractTraces{names} + traces::T + inds::NamedTuple{names,NTuple{N,Int}} +end -## -function sample(s::BatchSampler, t::Traces) - inds = rand(s.rng, 1:length(t), s.batch_size) - map(t.traces) do x - x[inds] - end |> s.transformer +function Base.(:*)(t1::AbstractTraces, t2::AbstractTraces) + k1, k2 = keys(t1), keys(t2) + ks = (k1..., k2...) + ts = (t1, t2) + inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...) + MergedTraces{ks,typeof(ts)}(ts, inds) +end + +function Base.(:*)(t1::AbstractTraces, t2::MergedTraces) + k1, k2 = keys(t1), keys(t2) + ks = (k1..., k2...) + ts = (t1, t2.traces...) + inds = (; (k => 1 for k in k1)..., map(x -> x + 1, t2.inds)...) + MergedTraces{ks,typeof(ts)}(ts, inds) +end + +function Base.(:*)(t1::MergedTraces, t2::AbstractTraces) + k1, k2 = keys(t1), keys(t2) + ks = (k1..., k2...) 
+ ts = (t1.traces..., t2) + inds = merge(t1.inds, (; (k => length(t1.traces) + 1 for k in k2)...)) + MergedTraces{ks,typeof(ts)}(ts, inds) +end + +function Base.push!(ts::MergedTraces, xs::NamedTuple) + for (k, v) in pairs(xs) + t = ts.traces[t.inds[k]] + push!(t, v) + end +end + +function Base.append!(ts::MergedTraces, xs::NamedTuple) + for (k, v) in pairs(xs) + t = ts.traces[t.inds[k]] + append!(t, v) + end +end + +function Base.pop!(ts::MergedTraces) + for t in ts.traces + pop!(t) + end +end + +function Base.popfirst!(ts::MergedTraces) + for t in ts.traces + popfirst!(t) + end +end + +function Base.empty!(ts::MergedTraces) + for t in ts.traces + empty!(t) + end end \ No newline at end of file From c9d97b7c9ca736f0c0741aebb5358afcea3a08de Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 22 May 2022 17:53:44 +0800 Subject: [PATCH 02/17] use StructArrays --- .gitignore | 3 + Project.toml | 1 + src/LastDimSlices.jl | 27 +++ src/Trajectories.jl | 7 +- src/common/CircularArraySARTTraces.jl | 37 +--- src/common/CircularArraySLARTTraces.jl | 46 +---- src/episodes.jl | 43 ++--- src/samplers.jl | 14 +- src/traces.jl | 233 ++++++++++--------------- src/trajectory.jl | 4 - test/traces.jl | 140 ++++++++++----- test/trajectories.jl | 12 +- 12 files changed, 265 insertions(+), 302 deletions(-) create mode 100644 src/LastDimSlices.jl diff --git a/.gitignore b/.gitignore index 0f84bed..86a3ec9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ +coverage/ *.jl.*.cov *.jl.cov *.jl.mem /Manifest.toml + +.DS_Store \ No newline at end of file diff --git a/Project.toml b/Project.toml index efa7a7b..bc8814c 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" 
[compat] diff --git a/src/LastDimSlices.jl b/src/LastDimSlices.jl new file mode 100644 index 0000000..a7e0030 --- /dev/null +++ b/src/LastDimSlices.jl @@ -0,0 +1,27 @@ +export LastDimSlices + +using MacroTools: @forward + +# See also https://github.com/JuliaLang/julia/pull/32310 + +struct LastDimSlices{T,E} <: AbstractVector{E} + parent::T +end + +function LastDimSlices(x::T) where {T<:AbstractArray} + E = eltype(x) + N = ndims(x) - 1 + P = typeof(x) + I = Tuple{ntuple(_ -> Base.Slice{Base.OneTo{Int}}, Val(ndims(x) - 1))...,Int} + LastDimSlices{T,SubArray{E,N,P,I,true}}(x) +end + +Base.convert(::Type{LastDimSlices}, x::AbstractVector) = x +Base.convert(::Type{LastDimSlices}, x::AbstractArray) = LastDimSlices(x) + +Base.size(x::LastDimSlices) = (size(x.parent, ndims(x.parent)),) +Base.getindex(s::LastDimSlices, I) = getindex(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) +Base.view(s::LastDimSlices, I) = view(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) +Base.setindex!(s::LastDimSlices, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) + +@forward LastDimSlices.parent Base.parent, Base.pushfirst!, Base.push!, Base.pop!, Base.append!, Base.prepend!, Base.empty! 
\ No newline at end of file diff --git a/src/Trajectories.jl b/src/Trajectories.jl index 40e876b..fba18fa 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,11 +1,12 @@ module Trajectories -include("samplers.jl") -include("controlers.jl") +include("LastDimSlices.jl") include("traces.jl") include("episodes.jl") +include("samplers.jl") +include("controlers.jl") include("trajectory.jl") -include("rendering.jl") +# include("rendering.jl") include("common/common.jl") end diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index f140b00..0d77547 100644 --- a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -1,16 +1,5 @@ export CircularArraySARTTraces -const CircularArraySARTTraces = Traces{ - SART, - <:Tuple{ - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer} - } -} - - function CircularArraySARTTraces(; capacity::Int, state=Int => (), @@ -23,32 +12,10 @@ function CircularArraySARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal + MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer - action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! 
action is one step longer reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end - -function Random.rand(s::BatchSampler, t::CircularArraySARTTraces) - inds = rand(s.rng, 1:length(t), s.batch_size) - inds′ = inds .+ 1 - ( - state=t[:state][inds], - action=t[:action][inds], - reward=t[:reward][inds], - terminal=t[:terminal][inds], - next_state=t[:state][inds′], - next_action=t[:state][inds′] - ) |> s.transformer -end - -function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA}) - if length(t[:state]) == length(t[:terminal]) + 1 - pop!(t[:state]) - pop!(t[:action]) - end - push!(t[:state], x[:state]) - push!(t[:action], x[:action]) -end diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 83e5d0d..66e3a2f 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -1,17 +1,5 @@ export CircularArraySLARTTraces -const CircularArraySLARTTraces = Traces{ - SLART, - <:Tuple{ - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer} - } -} - - function CircularArraySLARTTraces(; capacity::Int, state=Int => (), @@ -26,37 +14,11 @@ function CircularArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal + MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! 
state is one step longer - legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer - action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) -end - -function sample(s::BatchSampler, t::CircularArraySLARTTraces) - inds = rand(s.rng, 1:length(t), s.batch_size) - inds′ = inds .+ 1 - ( - state=t[:state][inds], - legal_actions_mask=t[:legal_actions_mask][inds], - action=t[:action][inds], - reward=t[:reward][inds], - terminal=t[:terminal][inds], - next_state=t[:state][inds′], - next_legal_actions_mask=t[:legal_actions_mask][inds′], - next_action=t[:state][inds′] - ) |> s.transformer -end - -function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA}) - if length(t[:state]) == length(t[:terminal]) + 1 - pop!(t[:state]) - pop!(t[:legal_actions_mask]) - pop!(t[:action]) - end - push!(t[:state], x[:state]) - push!(t[:legal_actions_mask], x[:legal_actions_mask]) - push!(t[:action], x[:action]) -end +end \ No newline at end of file diff --git a/src/episodes.jl b/src/episodes.jl index dcd5712..7932fea 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -1,6 +1,5 @@ export Episode, Episodes -using MLUtils: batch """ Episode(traces) @@ -8,34 +7,26 @@ using MLUtils: batch An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]` to check/update whether the episode reaches a terminal or not. 
""" -struct Episode{T} +struct Episode{T,E} <: AbstractVector{E} traces::T - is_done::Ref{Bool} + is_terminated::Ref{Bool} end -Base.getindex(e::Episode, s::Symbol) = getindex(e.traces, s) -Base.keys(e::Episode) = keys(e.traces) - +Base.getindex(e::Episode, I) = getindex(e.traces, I) Base.getindex(e::Episode) = getindex(e.is_done) Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x) -Base.length(e::Episode) = length(e.traces) +Base.size(e::Episode) = size(e.traces) -Episode(t::Traces) = Episode(t, Ref(false)) +Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false)) -function Base.push!(t::Episode, x) - if t.is_done[] - throw(ArgumentError("The episode is already flagged as done!")) - else - push!(t.traces, x) - end -end - -function Base.append!(t::Episode, x) - if t.is_done[] - throw(ArgumentError("The episode is already flagged as done!")) - else - append!(t.traces, x) +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(t::Episode, x) + if t.is_done[] + throw(ArgumentError("The episode is already flagged as done!")) + else + $f(t.traces, x) + end end end @@ -58,13 +49,14 @@ end A container for multiple [`Episode`](@ref)s. `init` is a parameterness function which return an [`Episode`](@ref). 
""" -struct Episodes +struct Episodes <: AbstractVector{Episode} init::Any episodes::Vector{Episode} inds::Vector{Tuple{Int,Int}} end -Base.length(e::Episodes) = length(e.inds) +Base.size(e::Episodes) = size(e.inds) +Base.getindex(e::Episodes, I) = getindex(e.episodes, I) function Base.push!(e::Episodes, x::Episode) push!(e.episodes, x) @@ -100,8 +92,3 @@ function Base.append!(e::Episodes, x) end ## - -function sample(s::BatchSampler, e::Episodes) - inds = rand(s.rng, 1:length(t), s.batch_size) - batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer -end \ No newline at end of file diff --git a/src/samplers.jl b/src/samplers.jl index b4c073c..f1afb25 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -1,5 +1,7 @@ export BatchSampler +using MLUtils: batch + using Random struct BatchSampler @@ -15,4 +17,14 @@ Uniformly sample a batch of examples for each trace. See also [`sample`](@ref). """ -BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity) +BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, transformer) + +function sample(s::BatchSampler, t::AbstractTraces) + inds = rand(s.rng, 1:length(t), s.batch_size) + @view t[inds] +end + +function sample(s::BatchSampler, e::Episodes) + inds = rand(s.rng, 1:length(t), s.batch_size) + batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer +end \ No newline at end of file diff --git a/src/traces.jl b/src/traces.jl index 0985b1c..ba28642 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,103 +1,40 @@ -export Trace, Traces, sample +export Trace, Traces, MultiplexTraces +using StructArrays import MacroTools: @forward -""" - AbstractTrace{T} <: AbstractVector{T} - -An `AbstractTrace` is a subtype of `AbstractVector`. 
The following methods should be implemented: - -- `Base.length` -- `Base.firstindex` -- `Base.lastindex` -- `Base.getindex` -- `Base.view` -- `Base.push!` -- `Base.append!` -- `Base.empty!` -- `Base.pop!` -- `Base.popfirst!` -""" -abstract type AbstractTrace{T} <: AbstractVector{T} end - -""" - AbstractTraces{names} - -An `AbstractTraces` is a group of different [`AbstractTrace`](@ref). Following methods must be implemented: - -- `Base.getindex`, get the inner `AbstractTrace` given a trace name. -- `Base.keys` -- `Base.haskey` -- `Base.push!` -- `Base.append!` -- `Base.pop!` -- `Base.popfirst!` -- `Base.empty!` -""" -abstract type AbstractTraces{names} end - -Base.keys(t::AbstractTraces{names}) where {names} = names -Base.haskey(t::AbstractTraces{names}) where {names} = haskey(names) - -Base.push!(t::AbstractTraces; kw...) = push!(t, values(kw)) - -function Base.push!(t::AbstractTraces, x::NamedTuple) - for k in keys(x) - push!(t[k], x[k]) - end -end - -Base.append!(t::AbstractTraces; kw...) = append!(t, values(kw)) - -function Base.append!(t::AbstractTraces, x::NamedTuple) - for k in keys(x) - append!(t[k], x[k]) - end -end - - -##### - -""" - Trace(data) - -The most common [`AbstractTrace`](@ref). A wrapper of arbitrary container. -Generally we assume the `data` is an `AbstractVector` like object. When an -`AbstractArray` is given, we view it as a vector of sub-arrays along the last -dimension. -""" -struct Trace{T} <: AbstractTrace - x::T -end - -@forward Trace.x Base.length, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.push!, Base.append!, Base.pop!, Base.popfirst!, Base.empty! +abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end -Base.convert(::Type{Trace}, x) = Trace(x) - -Base.length(t::Trace{<:AbstractArray}) = size(t.x, ndims(t.x)) -Base.getindex(t::Trace{<:AbstractArray}, I...) = getindex(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) -Base.view(t::Trace{<:AbstractArray}, I...) 
= view(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) - -##### +# function Base.show(io::IO, ::MIME"text/plain", t::AbstractTraces{names}) where {names} +# println(io, "$(length(names)) traces in total with $(length(t)) elements:") +# for n in names +# println(" :$n => $(summary(t[n]))") +# end +# end """ Traces(;kw...) - -A container of several named-[`AbstractTrace`](@ref)s. Each element in the `kw` will be converted into a `Trace`. """ -struct Traces{names,T} <: AbstractTraces - traces::NamedTuple{names,T} +struct Traces{T,names,E} <: AbstractTraces{names,E} + traces::T function Traces(; kw...) - traces = map(x -> convert(Trace, x), values(kw)) - new{keys(traces),typeof(values(traces))}(traces) + for (k, v) in kw + if !(v isa AbstractVector) + throw(ArgumentError("the value of $k should be an AbstractVector")) + end + end + + data = map(x -> convert(LastDimSlices, x), values(kw)) + t = StructArray(data) + new{typeof(t),keys(data),Tuple{typeof(data).types...}}(t) end end -@forward Traces.traces Base.getindex +@forward Traces.traces Base.size, Base.parent, Base.getindex, Base.setindex!, Base.view, Base.push!, Base.pushfirst!, Base.pop!, Base.popfirst!, Base.empty! -Base.pop!(t::Traces) = map(pop!, t.traces) -Base.popfirst!(t::Traces) = map(popfirst!, t.traces) -Base.empty!(t::Traces) = map(empty!, t.traces) +Base.append!(t::Traces, x::NamedTuple) = append!(t.traces, StructArray(x)) +Base.prepend!(t::Traces, x::NamedTuple) = prepend!(t.traces, StructArray(x)) +Base.getindex(t::Traces, s::Symbol) = getproperty(t.traces, s) ##### @@ -119,24 +56,20 @@ trace 0 1 2 3 4 5 6 7 8 9 This is quite common in RL to represent `states` and `next_states`. 
""" -struct MultiplexTraces{names,T} <: AbstractTraces{names} - trace::Trace{T} +struct MultiplexTraces{names,T,E} <: AbstractTraces{names,Tuple{E,E}} + trace::T end -function MultiplexTraces{names}(trace) where {names} +function MultiplexTraces{names}(t) where {names} if length(names) != 2 throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names")) end - t = convert(Trace, trace) - MultiplexTraces{names,typeof(t)}(t) + trace = convert(LastDimSlices, t) + MultiplexTraces{names,typeof(trace),eltype(trace)}(trace) end -@forward MultiplexTraces.trace Base.pop!, Base.popfirst!, Base.empty! - -Base.getindex(t::MultiplexTraces, i::Int) = getindex(t, keys(t)[i]) - -function Base.getindex(t::MultiplexTraces, k::Symbol) - a, b = keys(t) +function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} + a, b = names if k == a @view t.trace[1:end-1] elseif k == b @@ -146,83 +79,95 @@ function Base.getindex(t::MultiplexTraces, k::Symbol) end end -function Base.push!(t::MultiplexTraces, x::NamedTuple{ks,Tuple{Ts}}) where {ks,Ts} - k, v = first(ks), first(x) - if k in keys(t) - push!(t.trace, v) - else - throw(ArgumentError("unknown trace name: $k")) - end +Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}(t[k][I] for k in names) +Base.getindex(t::MultiplexTraces{names}, I) where {names} = StructArray(NamedTuple{names}(t[k][I] for k in names)) +Base.view(t::MultiplexTraces{names}, I) where {names} = StructArray(NamedTuple{names}(view(t[k], I) for k in names)) +Base.size(t::MultiplexTraces) = (max(0, length(t.trace) - 1),) + +function Base.setindex!(t::MultiplexTraces{names}, v::NamedTuple, i) where {names} + a, b = names + va, vb = getindex(v, a), getindex(v, b) + t.trace[i] = va + t.trace[i+1] = vb end -function Base.append!(t::MultiplexTraces, x::NamedTuple{ks,Tuple{Ts}}) where {ks,Ts} - k, v = first(ks), first(x) - if k in keys(t) - append!(t.trace, v) - else - throw(ArgumentError("unknown 
trace name: $k")) +@forward MultiplexTraces.trace Base.parent, Base.pop!, Base.popfirst!, Base.empty! + +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(t::MultiplexTraces{names}, x::NamedTuple{ks,Tuple{Ts}}) where {names,ks,Ts} + k, v = first(ks), first(x) + if k in names + $f(t.trace, v) + else + throw(ArgumentError("unknown trace name: $k")) + end end end ##### -struct MergedTraces{names,T,N} <: AbstractTraces{names} +struct MergedTraces{names,T,N,E} <: AbstractTraces{names,E} traces::T inds::NamedTuple{names,NTuple{N,Int}} end -function Base.(:*)(t1::AbstractTraces, t2::AbstractTraces) - k1, k2 = keys(t1), keys(t2) +Base.getindex(ts::MergedTraces, s::Symbol) = ts.traces[ts.inds[s]][s] + +function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::AbstractTraces{k2,T2}) where {k1,k2,T1,T2} ks = (k1..., k2...) ts = (t1, t2) inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...) - MergedTraces{ks,typeof(ts)}(ts, inds) + MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.(:*)(t1::AbstractTraces, t2::MergedTraces) - k1, k2 = keys(t1), keys(t2) +function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::MergedTraces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2} ks = (k1..., k2...) ts = (t1, t2.traces...) - inds = (; (k => 1 for k in k1)..., map(x -> x + 1, t2.inds)...) - MergedTraces{ks,typeof(ts)}(ts, inds) + inds = merge(NamedTuple(k => 1 for k in k1), map(v => v + 1, t1.inds)) + MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.(:*)(t1::MergedTraces, t2::AbstractTraces) - k1, k2 = keys(t1), keys(t2) + +function Base.:(+)(t1::MergedTraces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T,N,T1,k2,T2} ks = (k1..., k2...) 
ts = (t1.traces..., t2) - inds = merge(t1.inds, (; (k => length(t1.traces) + 1 for k in k2)...)) - MergedTraces{ks,typeof(ts)}(ts, inds) + inds = merge(t1.inds, (; (k => length(ts) for k in k2)...)) + MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.push!(ts::MergedTraces, xs::NamedTuple) - for (k, v) in pairs(xs) - t = ts.traces[t.inds[k]] - push!(t, v) - end +function Base.:(+)(t1::MergedTraces{k1,T1,N1,E1}, t2::MergedTraces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} + ks = (k1..., k2...) + ts = (t1.traces..., t2.traces...) + inds = merge(t1.inds, map(x -> x + length(t1.traces), t2.inds)) + MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.append!(ts::MergedTraces, xs::NamedTuple) - for (k, v) in pairs(xs) - t = ts.traces[t.inds[k]] - append!(t, v) - end -end -function Base.pop!(ts::MergedTraces) - for t in ts.traces - pop!(t) +Base.size(t::MergedTraces) = size(t.traces[1]) +Base.getindex(t::MergedTraces, I::Int) = mapreduce(x -> getindex(x, I), merge, t.traces) +Base.getindex(t::MergedTraces, I) = StructArray(mapreduce(x -> getfield(getindex(x, I), :components), merge, t.traces)) +Base.view(t::MergedTraces, I) = StructArray(mapreduce(x -> getfield(view(x, I), :components), merge, t.traces)) + +function Base.setindex!(t::MergedTraces, x::NamedTuple, I) + for (k, v) in pairs(x) + setindex!(t.traces[t.inds[k]], (; k => v), I) end end -function Base.popfirst!(ts::MergedTraces) - for t in ts.traces - popfirst!(t) + +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(ts::MergedTraces, xs::NamedTuple) + for (k, v) in pairs(xs) + t = ts.traces[ts.inds[k]] + $f(t, (; k => v)) + end end end -function Base.empty!(ts::MergedTraces) - for t in ts.traces - empty!(t) +for f in (:pop!, :popfirst!, :empty!) 
+ @eval function Base.$f(ts::MergedTraces) + for t in ts.traces + $f(t) + end end -end \ No newline at end of file +end diff --git a/src/trajectory.jl b/src/trajectory.jl index fe7373e..4e0bc70 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -54,8 +54,6 @@ Base.@kwdef struct Trajectory{C,S,T} end -Base.push!(t::Trajectory; kw...) = push!(t, values(kw)) - function Base.push!(t::Trajectory, x) n_pre = length(t.container) push!(t.container, x) @@ -72,8 +70,6 @@ end Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.push!, args, kw)) Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.append!, args, kw)) -Base.append!(t::Trajectory; kw...) = append!(t, values(kw)) - function Base.append!(t::Trajectory, x) n_pre = length(t.container) append!(t.container, x) diff --git a/test/traces.jl b/test/traces.jl index cfcd72e..593c1aa 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -1,57 +1,119 @@ -@testset "Trace 1d" begin - t = Trace([]) - @test length(t) == 0 +@testset "Traces" begin + t = Traces(; + a=[1, 2], + b=Bool[0, 1] + ) - push!(t, 1) - @test length(t) == 1 - @test t[1] == 1 + @test length(t) == 2 - append!(t, [2, 3]) - @test length(t) == 3 - @test @view(t[2:3]) == [2, 3] + push!(t, (; a=3, b=true)) + + @test t[:a][end] == 3 + @test t[:b][end] == true + + append!(t, (a=[4, 5], b=[false, false])) + @test length(t[:a]) == 5 + @test t[:b][end-1:end] == [false, false] + + @test t[1] == (a=1, b=false) + + t_12 = t[1:2] + @test t_12.a == [1, 2] + @test t_12.b == [false, true] + + t_12.a[1] = 0 + @test t[:a][1] != 0 + + t_12_view = @view t[1:2] + t_12_view.a[1] = 0 + @test t[:a][1] == 0 pop!(t) - @test length(t) == 2 + @test length(t) == 4 - s = BatchSampler(2) - @test size(sample(s, t)) == (2,) + popfirst!(t) + @test length(t) == 3 empty!(t) @test length(t) == 0 - end -@testset "Trace 2d" begin - t = Trace([ - 
1 2 3 - 4 5 6 - ]) - @test length(t) == 3 - @test t[1] == [1, 4] - @test @view(t[2:3]) == [2 3; 5 6] +@testset "MultiplexTraces" begin + t = MultiplexTraces{(:state, :next_state)}(Int[]) + + @test length(t) == 0 + + push!(t, (; state=1)) + push!(t, (; next_state=2)) + + @test t[:state] == [1] + @test t[:next_state] == [2] + @test t[1] == (state=1, next_state=2) - s = BatchSampler(5) - @test size(sample(s, t)) == (2, 5) + append!(t, (; state=[3, 4])) + + @test t[:state] == [1, 2, 3] + @test t[:next_state] == [2, 3, 4] + @test t[end] == (state=3, next_state=4) + + pop!(t) + t[end] == (state=2, next_state=3) + empty!(t) + @test length(t) == 0 end -@testset "Traces" begin - t = Traces(; - a=[1, 2], - b=Bool[0, 1] - ) +@testset "MergedTraces" begin + t1 = Traces(a=Int[]) + t2 = Traces(b=Bool[]) - @test keys(t) == (:a, :b) - @test haskey(t, :a) - @test t[:a] isa Trace + t3 = t1 + t2 + @test t3[:a] === t1[:a] + @test t3[:b] === t2[:b] - push!(t; a=3, b=true) - @test t[:a][end] == 3 - @test t[:b][end] == true + push!(t3, (; a=1, b=false)) + @test length(t3) == 1 + @test t3[1] == (a=1, b=false) - append!(t; a=[4, 5], b=[false, false]) - @test length(t[:a]) == 5 - @test t[:b][end-1:end] == [false, false] + append!(t3, (; a=[2, 3], b=[false, true])) + @test length(t3) == 3 + + @test t3[:a][1:3] == [1, 2, 3] + + t3_view = @view t3[1:3] + t3_view.a[1] = 0 + @test t3[:a][1] == 0 + + pop!(t3) + @test length(t3) == 2 + + empty!(t3) + @test length(t3) == 0 + + t4 = MultiplexTraces{(:m, :n)}(Float64[]) + t5 = t4 + t2 + t1 + + push!(t5, (m=1.0, n=1.0, a=1, b=1)) + @test length(t5) == 1 + + push!(t5, (m=2.0, a=2, b=0)) + + @test t5[end] == (m=1.0, n=2.0, b=false, a=2) + + t6 = Traces(aa=Int[]) + t7 = Traces(bb=Bool[]) + t8 = (t1 + t2) + (t6 + t7) + + empty!(t8) + push!(t8, (a=1, b=false, aa=1, bb=false)) + append!(t8, (a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true])) + + @test length(t8) == 3 + + t8_view = @view t8[2:3] + t8_view.a[1] = 0 + @test t8[:a][2] == 0 - s = 
BatchSampler(5) - @test size(sample(s, t)[:a]) == (5,) + t8_slice = t8[2:3] + t8_slice.a[1] = -1 + @test t8[:a][2] != -1 end \ No newline at end of file diff --git a/test/trajectories.jl b/test/trajectories.jl index c6b1d9e..7c50718 100644 --- a/test/trajectories.jl +++ b/test/trajectories.jl @@ -16,7 +16,7 @@ @test length(batches) == 0 # threshold not reached yet - append!(t; a=[1, 2, 3], b=[false, true, false]) + append!(t, (a=[1, 2, 3], b=[false, true, false])) for batch in t push!(batches, batch) @@ -24,7 +24,7 @@ @test length(batches) == 0 # threshold not reached yet - push!(t; a=4, b=true) + push!(t, (a=4, b=true)) for batch in t push!(batches, batch) @@ -32,7 +32,7 @@ @test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25 - append!(t; a=[5, 6, 7], b=[true, true, true]) + append!(t, (a=[5, 6, 7], b=[true, true, true])) for batch in t push!(batches, batch) @@ -40,7 +40,7 @@ @test length(batches) == 1 # 7 inserted, threshold is 4, ratio is 0.25 - push!(t; a=8, b=true) + push!(t, (a=8, b=true)) for batch in t push!(batches, batch) @@ -50,7 +50,7 @@ n = 100 for i in 1:n - append!(t; a=[i, i, i, i], b=[false, true, false, true]) + append!(t, (a=[i, i, i, i], b=[false, true, false, true])) end s = 0 @@ -74,7 +74,7 @@ end n = 100 insert_task = @async for i in 1:n - append!(t; a=[i, i, i, i], b=[false, true, false, true]) + append!(t, (a=[i, i, i, i], b=[false, true, false, true])) end s = 0 From 7a27b4ae55b1674d1a75169fbadbf28a5917c84e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 22 May 2022 21:14:53 +0800 Subject: [PATCH 03/17] merge #9 --- src/Trajectories.jl | 2 +- src/common/common.jl | 8 --- src/episodes.jl | 2 - src/rendering.jl | 135 ------------------------------------------- src/traces.jl | 13 +---- src/trajectory.jl | 13 ++++- 6 files changed, 14 insertions(+), 159 deletions(-) delete mode 100644 src/rendering.jl diff --git a/src/Trajectories.jl b/src/Trajectories.jl index fba18fa..f99d27d 100644 --- a/src/Trajectories.jl +++ 
b/src/Trajectories.jl @@ -6,7 +6,7 @@ include("episodes.jl") include("samplers.jl") include("controlers.jl") include("trajectory.jl") -# include("rendering.jl") +include("rendering.jl") include("common/common.jl") end diff --git a/src/common/common.jl b/src/common/common.jl index 271b149..4377083 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -1,13 +1,5 @@ using CircularArrayBuffers -const SA = (:state, :action) -const SLA = (:state, :legal_actions_mask, :action) -const RT = (:reward, :terminal) -const SART = (:state, :action, :reward, :terminal) -const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) -const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal) -const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action) - include("sum_tree.jl") include("CircularArraySARTTraces.jl") include("CircularArraySLARTTraces.jl") diff --git a/src/episodes.jl b/src/episodes.jl index 7932fea..86812c2 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -90,5 +90,3 @@ function Base.append!(e::Episodes, x) push!(e.inds, (lengthe.episodes, i)) end end - -## diff --git a/src/rendering.jl b/src/rendering.jl deleted file mode 100644 index 1de5afa..0000000 --- a/src/rendering.jl +++ /dev/null @@ -1,135 +0,0 @@ -using Term - -const TRACE_COLORS = ("bright_green", "hot_pink", "bright_blue", "light_coral", "bright_cyan", "sandy_brown", "violet") - -Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes,Trajectory}) = tprint(io, convert(Term.AbstractRenderable, t; width=displaysize(io)[2]) |> string) - -inner_convert(::Type{Term.AbstractRenderable}, s::String; style="gray1", width=88) = Panel(s, width=width, style=style, justify=:center) -inner_convert(t::Type{Term.AbstractRenderable}, x::Union{Symbol,Number}; kw...) = inner_convert(t, string(x); kw...) 
- -function inner_convert(::Type{Term.AbstractRenderable}, x::AbstractArray; style="gray1", width=88) - t = string(nameof(typeof(x))) - s = replace(string(size(x)), " " => "") - Panel(t * "\n" * s, style=style, justify=:center, width=width) -end - -function inner_convert(::Type{Term.AbstractRenderable}, x; style="gray1", width=88) - s = string(nameof(typeof(x))) - Panel(s, style=style, justify=:center, width=width) -end - -Base.convert(T::Type{Term.AbstractRenderable}, t::Trace{<:AbstractArray}; kw...) = convert(T, Trace(collect(eachslice(t.x, dims=ndims(t.x)))); kw..., type=typeof(t), subtitle="size: $(size(t.x))") - -function Base.convert( - ::Type{Term.AbstractRenderable}, - t::Trace{<:AbstractVector}; - width=88, - n_head=2, - n_tail=1, - name="Trace", - style=TRACE_COLORS[mod1(hash(name), length(TRACE_COLORS))], - type=typeof(t), - subtitle="size: $(size(t.x))" -) - title = "$name: [italic]$type[/italic] " - min_width = min(width, length(title) - 4) - - n = length(t.x) - if n == 0 - content = "" - elseif 1 <= n <= n_head + n_tail - content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x) - else - content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[1:n_head]) / - TextBox("...", justify=:center, width=min_width - 6) / - mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[end-n_tail+1:end]) - end - Panel(content, width=min_width, title=title, subtitle=subtitle, subtitle_justify=:right, style=style, subtitle_style="yellow") -end - -function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width=88) - max_len = mapreduce(length, max, t.traces) - min_len = mapreduce(length, min, t.traces) - if max_len - min_len == 1 - n_tails = [length(x) == max_len ? 
2 : 1 for x in t.traces] - else - n_tails = [1 for x in t.traces] - end - N = length(t.traces) - max_inner_width = ceil(Int, (width - 6 * 2) / N) - Panel( - mapreduce(((i, x),) -> convert(Term.AbstractRenderable, t[x]; width=max_inner_width, name=x, n_tail=n_tails[i], style=TRACE_COLORS[mod1(i, length(TRACE_COLORS))]), *, enumerate(keys(t))), - title="Traces", - style="yellow3", - subtitle="$N traces in total", - subtitle_justify=:right, - width=width, - fit=true - ) -end - -function Base.convert(::Type{Term.AbstractRenderable}, e::Episode; width=88) - Panel( - convert(Term.AbstractRenderable, e.traces; width=width - 6), - title="Episode", - style="green_yellow", - subtitle=e[] ? "Episode END" : "Episode growing...", - subtitle_justify=:right, - width=width, - fit=true - ) -end - -function Base.convert(::Type{Term.AbstractRenderable}, e::Episodes; width=88) - n = length(e) - if n == 0 - content = "" - elseif n == 1 - content = convert(Term.AbstractRenderable, e[1], width=width - 6) - elseif n == 2 - content = convert(Term.AbstractRenderable, e[1], width=width - 6) / - convert(Term.AbstractRenderable, e[end], width=width - 6) - else - content = convert(Term.AbstractRenderable, e[1], width=width - 6) / - TextBox("...", justify=:center, width=width - 6) / - convert(Term.AbstractRenderable, e[end], width=width - 6) - end - - Panel( - content, - title="Episodes", - subtitle="$n episodes in total", - subtitle_justify=:right, - width=width, - fit=true, - style="wheat1" - ) -end - -function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width=88) - Panel( - convert(r, t.container; width=width - 8) / - Panel(convert(Term.Tree, t.sampler); title="sampler", style="yellow3", fit=true, width=width - 8) / - Panel(convert(Term.Tree, t.controler); title="controler", style="yellow3", fit=true, width=width - 8); - title="Trajectory", - style="yellow3", - width=width, - fit=true - ) -end - -# general converter - -Base.convert(::Type{Term.Tree}, x) = 
Tree(to_tree_body(x); title=to_tree_title(x)) -Base.convert(::Type{Term.Tree}, x::Tree) = x - -function to_tree_body(x) - pts = propertynames(x) - if length(pts) > 0 - Dict("$p => $(summary(getproperty(x, p)))" => to_tree_body(getproperty(x, p)) for p in pts) - else - x - end -end - -to_tree_title(x) = "$(summary(x))" \ No newline at end of file diff --git a/src/traces.jl b/src/traces.jl index ba28642..fd990bb 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -5,12 +5,7 @@ import MacroTools: @forward abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end -# function Base.show(io::IO, ::MIME"text/plain", t::AbstractTraces{names}) where {names} -# println(io, "$(length(names)) traces in total with $(length(t)) elements:") -# for n in names -# println(" :$n => $(summary(t[n]))") -# end -# end +Base.keys(t::AbstractTraces{names}) where {names} = names """ Traces(;kw...) @@ -18,12 +13,6 @@ abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end struct Traces{T,names,E} <: AbstractTraces{names,E} traces::T function Traces(; kw...) 
- for (k, v) in kw - if !(v isa AbstractVector) - throw(ArgumentError("the value of $k should be an AbstractVector")) - end - end - data = map(x -> convert(LastDimSlices, x), values(kw)) t = StructArray(data) new{typeof(t),keys(data),Tuple{typeof(data).types...}}(t) diff --git a/src/trajectory.jl b/src/trajectory.jl index 4e0bc70..6b9f656 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,7 +1,9 @@ -export Trajectory +export Trajectory, TrajectoryStyle, SyncTrajectoryStyle, AsyncTrajectoryStyle using Base.Threads +struct AsyncTrajectoryStyle end +struct SyncTrajectoryStyle end """ Trajectory(container, sampler, controler) @@ -53,6 +55,15 @@ Base.@kwdef struct Trajectory{C,S,T} end end +TrajectoryStyle(::Trajectory) = SyncTrajectoryStyle() +TrajectoryStyle(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}) = AsyncTrajectoryStyle() + +Base.bind(::Trajectory, ::Task) = nothing + +function Base.bind(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, task) + bind(t.controler.ch_in, task) + bind(t.controler.ch_out, task) +end function Base.push!(t::Trajectory, x) n_pre = length(t.container) From 59536beb24fac4114b13e81863d04212614b9382 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 22 May 2022 21:20:31 +0800 Subject: [PATCH 04/17] fix ci --- src/Trajectories.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Trajectories.jl b/src/Trajectories.jl index f99d27d..8ef7ef9 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -6,7 +6,6 @@ include("episodes.jl") include("samplers.jl") include("controlers.jl") include("trajectory.jl") -include("rendering.jl") include("common/common.jl") end From 856d3d59d874e703d6437a342b3d8269ce6aff1e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 22 May 2022 21:25:44 +0800 Subject: [PATCH 05/17] rename is_done => is_terminated --- src/episodes.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/episodes.jl b/src/episodes.jl index 86812c2..fef27ff 100644 --- 
a/src/episodes.jl +++ b/src/episodes.jl @@ -13,8 +13,8 @@ struct Episode{T,E} <: AbstractVector{E} end Base.getindex(e::Episode, I) = getindex(e.traces, I) -Base.getindex(e::Episode) = getindex(e.is_done) -Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x) +Base.getindex(e::Episode) = getindex(e.is_terminated) +Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x) Base.size(e::Episode) = size(e.traces) @@ -22,7 +22,7 @@ Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false)) for f in (:push!, :pushfirst!, :append!, :prepend!) @eval function Base.$f(t::Episode, x) - if t.is_done[] + if t.is_terminated[] throw(ArgumentError("The episode is already flagged as done!")) else $f(t.traces, x) @@ -32,14 +32,14 @@ end function Base.pop!(t::Episode) pop!(t.traces) - t.is_done[] = false + t.is_terminated[] = false end Base.popfirst!(t::Episode) = popfirst!(t.traces) function Base.empty!(t::Episode) empty!(t.traces) - t.is_done[] = false + t.is_terminated[] = false end ##### From f835470eec6eafa7fb53ffc7428420ea54f9f6ef Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 23 May 2022 11:41:07 +0800 Subject: [PATCH 06/17] remove LastDimSlices --- Project.toml | 1 - src/LastDimSlices.jl | 22 --------- src/Trajectories.jl | 1 - src/traces.jl | 106 +++++++++++++++++++++++++++---------------- test/traces.jl | 15 ++---- 5 files changed, 71 insertions(+), 74 deletions(-) diff --git a/Project.toml b/Project.toml index bc8814c..efa7a7b 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" [compat] diff --git a/src/LastDimSlices.jl b/src/LastDimSlices.jl index a7e0030..188597e 100644 --- a/src/LastDimSlices.jl +++ 
b/src/LastDimSlices.jl @@ -3,25 +3,3 @@ export LastDimSlices using MacroTools: @forward # See also https://github.com/JuliaLang/julia/pull/32310 - -struct LastDimSlices{T,E} <: AbstractVector{E} - parent::T -end - -function LastDimSlices(x::T) where {T<:AbstractArray} - E = eltype(x) - N = ndims(x) - 1 - P = typeof(x) - I = Tuple{ntuple(_ -> Base.Slice{Base.OneTo{Int}}, Val(ndims(x) - 1))...,Int} - LastDimSlices{T,SubArray{E,N,P,I,true}}(x) -end - -Base.convert(::Type{LastDimSlices}, x::AbstractVector) = x -Base.convert(::Type{LastDimSlices}, x::AbstractArray) = LastDimSlices(x) - -Base.size(x::LastDimSlices) = (size(x.parent, ndims(x.parent)),) -Base.getindex(s::LastDimSlices, I) = getindex(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) -Base.view(s::LastDimSlices, I) = view(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) -Base.setindex!(s::LastDimSlices, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) - -@forward LastDimSlices.parent Base.parent, Base.pushfirst!, Base.push!, Base.pop!, Base.append!, Base.prepend!, Base.empty! 
\ No newline at end of file diff --git a/src/Trajectories.jl b/src/Trajectories.jl index 8ef7ef9..10e2d2d 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,6 +1,5 @@ module Trajectories -include("LastDimSlices.jl") include("traces.jl") include("episodes.jl") include("samplers.jl") diff --git a/src/traces.jl b/src/traces.jl index fd990bb..c36d104 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,11 +1,44 @@ export Trace, Traces, MultiplexTraces -using StructArrays import MacroTools: @forward +##### + +struct Trace{T,E} <: AbstractVector{E} + parent::T +end + +function Trace(x::T) where {T<:AbstractArray} + E = eltype(x) + N = ndims(x) - 1 + P = typeof(x) + I = Tuple{ntuple(_ -> Base.Slice{Base.OneTo{Int}}, Val(ndims(x) - 1))...,Int} + Trace{T,SubArray{E,N,P,I,true}}(x) +end + +Base.convert(::Type{Trace}, x::AbstractArray) = Trace(x) + +Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),) +Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) +Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) + +@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty! + +##### + +""" +For each concrete `AbstractTraces`, we have the following assumption: + +1. Every inner trace is an `AbstractVector` +1. Support partial updating +1. Return *View* by default when getting elements. +""" abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end Base.keys(t::AbstractTraces{names}) where {names} = names +Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names + +##### """ Traces(;kw...) @@ -13,17 +46,29 @@ Base.keys(t::AbstractTraces{names}) where {names} = names struct Traces{T,names,E} <: AbstractTraces{names,E} traces::T function Traces(; kw...) 
- data = map(x -> convert(LastDimSlices, x), values(kw)) - t = StructArray(data) - new{typeof(t),keys(data),Tuple{typeof(data).types...}}(t) + data = map(x -> convert(Trace, x), values(kw)) + new{typeof(data),keys(data),Tuple{typeof(data).types...}}(data) end end -@forward Traces.traces Base.size, Base.parent, Base.getindex, Base.setindex!, Base.view, Base.push!, Base.pushfirst!, Base.pop!, Base.popfirst!, Base.empty! +Base.getindex(t::Traces, s::Symbol) = getindex(t.traces, s) +Base.getindex(t::Traces, i) = map(x -> getindex(x, i), t.traces) + +@forward Traces.traces Base.parent -Base.append!(t::Traces, x::NamedTuple) = append!(t.traces, StructArray(x)) -Base.prepend!(t::Traces, x::NamedTuple) = prepend!(t.traces, StructArray(x)) -Base.getindex(t::Traces, s::Symbol) = getproperty(t.traces, s) +Base.size(t::Traces) = (mapreduce(length, min, t.traces),) + +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(ts::Traces, xs::NamedTuple) + for (k, v) in pairs(xs) + $f(ts.traces[k], v) + end + end +end + +for f in (:pop!, :popfirst!, :empty!) 
+ @eval Base.$f(ts::Traces) = map($f, ts.traces) +end ##### @@ -49,36 +94,29 @@ struct MultiplexTraces{names,T,E} <: AbstractTraces{names,Tuple{E,E}} trace::T end -function MultiplexTraces{names}(t) where {names} +function MultiplexTraces{names}(t::AbstractVector) where {names} if length(names) != 2 throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names")) end - trace = convert(LastDimSlices, t) + trace = convert(Trace, t) MultiplexTraces{names,typeof(trace),eltype(trace)}(trace) end function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} a, b = names if k == a - @view t.trace[1:end-1] + Trace(t.trace[1:end-1]) elseif k == b - @view t.trace[2:end] + Trace(t.trace[2:end]) else throw(ArgumentError("unknown trace name: $k")) end end -Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}(t[k][I] for k in names) -Base.getindex(t::MultiplexTraces{names}, I) where {names} = StructArray(NamedTuple{names}(t[k][I] for k in names)) -Base.view(t::MultiplexTraces{names}, I) where {names} = StructArray(NamedTuple{names}(view(t[k], I) for k in names)) -Base.size(t::MultiplexTraces) = (max(0, length(t.trace) - 1),) +Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}((t.trace[I], t.trace[I+1])) +Base.getindex(t::MultiplexTraces{names}, I::AbstractArray{Int}) where {names} = NamedTuple{names}((t.trace[I], t.trace[I.+1])) -function Base.setindex!(t::MultiplexTraces{names}, v::NamedTuple, i) where {names} - a, b = names - va, vb = getindex(v, a), getindex(v, b) - t.trace[i] = va - t.trace[i+1] = vb -end +Base.size(t::MultiplexTraces) = (max(0, length(t.trace) - 1),) @forward MultiplexTraces.trace Base.parent, Base.pop!, Base.popfirst!, Base.empty! @@ -106,14 +144,14 @@ function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::AbstractTraces{k2,T2}) where { ks = (k1..., k2...) ts = (t1, t2) inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...) 
- MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) + MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::MergedTraces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2} ks = (k1..., k2...) ts = (t1, t2.traces...) inds = merge(NamedTuple(k => 1 for k in k1), map(v => v + 1, t1.inds)) - MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) + MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end @@ -121,34 +159,24 @@ function Base.:(+)(t1::MergedTraces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where ks = (k1..., k2...) ts = (t1.traces..., t2) inds = merge(t1.inds, (; (k => length(ts) for k in k2)...)) - MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) + MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end function Base.:(+)(t1::MergedTraces{k1,T1,N1,E1}, t2::MergedTraces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} ks = (k1..., k2...) ts = (t1.traces..., t2.traces...) 
inds = merge(t1.inds, map(x -> x + length(t1.traces), t2.inds)) - MergedTraces{ks,typeof(ts),length(k1) + length(k2),Tuple{T1.types...,T2.types...}}(ts, inds) + MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end -Base.size(t::MergedTraces) = size(t.traces[1]) -Base.getindex(t::MergedTraces, I::Int) = mapreduce(x -> getindex(x, I), merge, t.traces) -Base.getindex(t::MergedTraces, I) = StructArray(mapreduce(x -> getfield(getindex(x, I), :components), merge, t.traces)) -Base.view(t::MergedTraces, I) = StructArray(mapreduce(x -> getfield(view(x, I), :components), merge, t.traces)) - -function Base.setindex!(t::MergedTraces, x::NamedTuple, I) - for (k, v) in pairs(x) - setindex!(t.traces[t.inds[k]], (; k => v), I) - end -end - +Base.size(t::MergedTraces) = (mapreduce(length, min, t.traces),) +Base.getindex(t::MergedTraces, I) = mapreduce(x -> getindex(x, I), merge, t.traces) for f in (:push!, :pushfirst!, :append!, :prepend!) @eval function Base.$f(ts::MergedTraces, xs::NamedTuple) for (k, v) in pairs(xs) - t = ts.traces[ts.inds[k]] - $f(t, (; k => v)) + $f(ts.traces[ts.inds[k]], (; k => v)) end end end diff --git a/test/traces.jl b/test/traces.jl index 593c1aa..ec81293 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -21,10 +21,7 @@ @test t_12.a == [1, 2] @test t_12.b == [false, true] - t_12.a[1] = 0 - @test t[:a][1] != 0 - - t_12_view = @view t[1:2] + t_12_view = t[1:2] t_12_view.a[1] = 0 @test t[:a][1] == 0 @@ -79,8 +76,8 @@ end @test t3[:a][1:3] == [1, 2, 3] - t3_view = @view t3[1:3] - t3_view.a[1] = 0 + t3_view = t3[1:3] + t3_view[:a][1] = 0 @test t3[:a][1] == 0 pop!(t3) @@ -109,11 +106,7 @@ end @test length(t8) == 3 - t8_view = @view t8[2:3] + t8_view = t8[2:3] t8_view.a[1] = 0 @test t8[:a][2] == 0 - - t8_slice = t8[2:3] - t8_slice.a[1] = -1 - @test t8[:a][2] != -1 end \ No newline at end of file From 50f25a12ee0ff7e74b5c10f8d5cdf441c7f15dd3 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 23 May 2022 20:37:45 +0800 Subject: 
[PATCH 07/17] update README --- README.md | 254 ++++++++++++++++++++++++++++++++++++++++++++++-- src/samplers.jl | 12 +-- src/traces.jl | 24 +++-- 3 files changed, 269 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index f650d21..1e6ba6f 100644 --- a/README.md +++ b/README.md @@ -6,23 +6,259 @@ ## Design -A typical example of `Trajectory`: +The relationship of several concepts provided in this package: -![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png) +``` +┌───────────────────────────────────┐ +│ Trajectory │ +│ ┌───────────────────────────────┐ │ +│ │ AbstractTraces │ │ +│ │ ┌───────────────┐ │ │ +│ │ :trace_A => │ AbstractTrace │ │ │ +│ │ └───────────────┘ │ │ +│ │ │ │ +│ │ ┌───────────────┐ │ │ +│ │ :trace_B => │ AbstractTrace │ │ │ +│ │ └───────────────┘ │ │ +│ │ ... ... │ │ +│ └───────────────────────────────┘ │ +│ ┌───────────┐ │ +│ │ Sampler │ │ +│ └───────────┘ │ +│ ┌────────────┐ │ +│ │ Controller │ │ +│ └────────────┘ │ +└───────────────────────────────────┘ +``` + +## `Trajectory` + +A `Trajectory` contains 3 parts: -Exported APIs are: +- A `container` to store data. (Usually an `AbstractTraces`) +- A `sampler` to determine how to sample a batch from `container` +- A `controller` to decide when to sample a new batch from the `container` + +Typical usage: ```julia -push!(trajectory; [trace_name=value]...) -append!(trajectory; [trace_name=value]...) +julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3)); + +julia> for i in 1:5 + push!(t, (a=i, b=iseven(i))) + end -for sample in trajectory - # consume samples from the trajectory -end +julia> for batch in t + println(batch) + end +(a = [4, 5, 1], b = Bool[1, 0, 0]) +(a = [3, 2, 4], b = Bool[0, 1, 1]) +(a = [4, 1, 2], b = Bool[1, 0, 1]) ``` -A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc. 
+### `AbstractTrace` + +`Trace` is the most commonly used `AbstractTrace`. It provides a sequential view on other containers. + +```julia +julia> t = Trace([1,2,3]) +3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: + 1 + 2 + 3 +julia> push!(t, 4) +4-element Vector{Int64}: + 1 + 2 + 3 + 4 + +julia> append!(t, 5:6) +6-element Vector{Int64}: + 1 + 2 + 3 + 4 + 5 + 6 + +julia> pop!(t) +6 + +julia> popfirst!(t) +1 + +julia> t +4-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: + 2 + 3 + 4 + 5 + +julia> empty!(t) +Int64[] + +julia> t +0-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} +``` + +In most cases, it behaves just like a `Vector`. + +When a higher-dimensional `AbstractArray` is provided, it is **slice**d along the last dimension to provide a sequential view. + +```julia +julia> t = Trace(rand(2,3)) +3-element Trace{Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}: + [0.276012181224494, 0.6621365818458671] + [0.9937726056924112, 0.3308302850028162] + [0.9856543000075456, 0.6123660950650406] + +julia> t[1] +2-element view(::Matrix{Float64}, :, 1) with eltype Float64: + 0.276012181224494 + 0.6621365818458671 + +julia> t[1] = [0., 1.] +2-element Vector{Float64}: + 0.0 + 1.0 + +julia> t +3-element Trace{Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}: + [0.0, 1.0] + [0.9937726056924112, 0.3308302850028162] + [0.9856543000075456, 0.6123660950650406] + +julia> t[[2,3,1]] +2×3 view(::Matrix{Float64}, :, [2, 3, 1]) with eltype Float64: + 0.993773 0.985654 0.0 + 0.33083 0.612366 1.0 +``` + +**Note** that when indexing a `Trace`, a **view** is returned. As you can see above, the data is modified in-place. + +### `AbstractTraces` + +`Traces` is one of the common `AbstractTraces`. It is similar to a `NamedTuple` of several traces. 
+ +```julia +julia> t = Traces(; + a=[1, 2], + b=Bool[0, 1] + ) # note that `a` and `b` are converted into `Trace` implicitly +Traces with 2 traces: + :a => 2-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} + :b => 2-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}} + +julia> push!(t, (a=3, b=false)) + +julia> t +Traces with 2 traces: + :a => 3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} + :b => 3-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}} + + +julia> t[:a] +3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: + 1 + 2 + 3 + +julia> t[:b] +3-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}}: + false + true + false + +julia> t[1] +(a = 1, b = false) + +julia> t[1:3] +(a = [1, 2, 3], b = Bool[0, 1, 0]) +``` + +Another commonly used `AbstractTraces` is `MultiplexTraces`. In reinforcement learning, *states* and *next-states* share most data except for the first and last element. 
+ +```julia +julia> t = MultiplexTraces{(:state, :next_state)}([1,2,3]); + +julia> t[:state] +2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 1 + 2 + +julia> t[:next_state] +2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 2 + 3 + +julia> push!(t, (;state=4)) +4-element Vector{Int64}: + 1 + 2 + 3 + 4 + +julia> t[:state] +3-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 1 + 2 + 3 + +julia> t[:next_state] +3-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 2 + 3 + 4 + +julia> length(t) +3 +``` + +Note that different kinds of `AbstractTraces` can be combined to form a `MergedTraces`. 
+ +```julia +julia> t1 = Traces(a=Int[]) + t2 = MultiplexTraces{(:b, :c)}(Int[]) + t3 = t1 + t2 +MergedTraces with 3 traces: + :a => 0-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} + :b => 0-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}} + :c => 0-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}} + + +julia> push!(t3, (a=1,b=2,c=3)) + +julia> t3[:a] +1-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: + 1 + +julia> t3[:b] +1-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 2 + +julia> t3[:c] +1-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 3 + +julia> push!(t3, (a=-1, b=-2)) + +julia> t3[:a] +2-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: + 1 + -1 + +julia> t3[:b] +2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 2 + 3 + +julia> t3[:c] +2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: + 3 + -2 +``` ## Acknowledgement This async version is mainly inspired by [deepmind/reverb](https://github.com/deepmind/reverb). 
diff --git a/src/samplers.jl b/src/samplers.jl index f1afb25..3877da9 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -17,14 +17,14 @@ Uniformly sample a batch of examples for each trace. See also [`sample`](@ref). """ -BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, transformer) +BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer) function sample(s::BatchSampler, t::AbstractTraces) inds = rand(s.rng, 1:length(t), s.batch_size) - @view t[inds] + map(s.transformer, t[inds]) end -function sample(s::BatchSampler, e::Episodes) - inds = rand(s.rng, 1:length(t), s.batch_size) - batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer -end \ No newline at end of file +# function sample(s::BatchSampler, e::Episodes) +# inds = rand(s.rng, 1:length(t), s.batch_size) +# [s.episodes[e.inds[i][1]][e.inds[i][2]] for i in inds] |> s.transformer +# end \ No newline at end of file diff --git a/src/traces.jl b/src/traces.jl index c36d104..ce7ace3 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -4,7 +4,11 @@ import MacroTools: @forward ##### -struct Trace{T,E} <: AbstractVector{E} +abstract type AbstractTrace{E} <: AbstractVector{E} end + +Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x + +struct Trace{T,E} <: AbstractTrace{E} parent::T end @@ -16,7 +20,7 @@ function Trace(x::T) where {T<:AbstractArray} Trace{T,SubArray{E,N,P,I,true}}(x) end -Base.convert(::Type{Trace}, x::AbstractArray) = Trace(x) +Base.convert(::Type{AbstractTrace}, x::AbstractArray) = Trace(x) Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),) Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) 
@@ -35,6 +39,14 @@ For each concrete `AbstractTraces`, we have the following assumption: """ abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end +function Base.show(io::IO, ::MIME"text/plain", t::AbstractTraces{names,T}) where {names,T} + s = nameof(typeof(t)) + println(io, "$s with $(length(names)) traces:") + for n in names + println(io, " :$n => $(summary(t[n]))") + end +end + Base.keys(t::AbstractTraces{names}) where {names} = names Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names @@ -46,7 +58,7 @@ Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names struct Traces{T,names,E} <: AbstractTraces{names,E} traces::T function Traces(; kw...) - data = map(x -> convert(Trace, x), values(kw)) + data = map(x -> convert(AbstractTrace, x), values(kw)) new{typeof(data),keys(data),Tuple{typeof(data).types...}}(data) end end @@ -98,16 +110,16 @@ function MultiplexTraces{names}(t::AbstractVector) where {names} if length(names) != 2 throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names")) end - trace = convert(Trace, t) + trace = convert(AbstractTrace, t) MultiplexTraces{names,typeof(trace),eltype(trace)}(trace) end function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} a, b = names if k == a - Trace(t.trace[1:end-1]) + convert(AbstractTrace, t.trace[1:end-1]) elseif k == b - Trace(t.trace[2:end]) + convert(AbstractTrace, t.trace[2:end]) else throw(ArgumentError("unknown trace name: $k")) end From addd2b7ae90fff0e275facea1871f35b42baf249 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 23 May 2022 21:36:51 +0800 Subject: [PATCH 08/17] add more tests --- src/traces.jl | 2 +- test/common.jl | 64 ++++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 test/common.jl diff --git a/src/traces.jl b/src/traces.jl index ce7ace3..47d4124 100644 --- a/src/traces.jl 
+++ b/src/traces.jl @@ -106,7 +106,7 @@ struct MultiplexTraces{names,T,E} <: AbstractTraces{names,Tuple{E,E}} trace::T end -function MultiplexTraces{names}(t::AbstractVector) where {names} +function MultiplexTraces{names}(t) where {names} if length(names) != 2 throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $length(names) trace names")) end diff --git a/test/common.jl b/test/common.jl new file mode 100644 index 0000000..36eb8c9 --- /dev/null +++ b/test/common.jl @@ -0,0 +1,64 @@ +@testset "CircularArraySARTTraces" begin + t = CircularArraySARTTraces(; + capacity=3, + state=Float32 => (2, 3), + action=Float32 => (2), + reward=Float32 => (), + terminal=Bool => () + ) + + push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2))) + @test length(t) == 0 + + push!(t, (reward=1.0f0, terminal=false)) + @test length(t) == 0 # next_state and next_action is still missing + + push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2)) + @test length(t) == 1 + + @test t[1] == ( + state=ones(Float32, 2, 3), + next_state=ones(Float32, 2, 3) * 2, + action=ones(Float32, 2), + next_action=ones(Float32, 2) * 2, + reward=1.0f0, + terminal=false, + ) + + push!(t, (reward=2.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3)) + + @test length(t) == 2 + + push!(t, (reward=3.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4)) + + @test length(t) == 3 + + push!(t, (reward=4.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5)) + + @test length(t) == 3 + @test t[1] == ( + state=ones(Float32, 2, 3) * 2, + next_state=ones(Float32, 2, 3) * 3, + action=ones(Float32, 2) * 2, + next_action=ones(Float32, 2) * 3, + reward=2.0f0, + terminal=false, + ) + @test t[end] == ( + state=ones(Float32, 2, 3) * 4, + next_state=ones(Float32, 2, 3) * 5, + action=ones(Float32, 2) * 4, + next_action=ones(Float32, 2) * 5, + reward=4.0f0, + 
terminal=false, + ) + + batch = t[1:3] + @test size(batch.state) == (2, 3, 3) + @test size(batch.action) == (2, 3) + @test batch.reward == [2.0, 3.0, 4.0] + @test batch.terminal == Bool[0, 0, 0] +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 91680af..d1a2a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,9 @@ using Trajectories +using CircularArrayBuffers using Test @testset "Trajectories.jl" begin include("traces.jl") + include("common.jl") include("trajectories.jl") end From aca0601557d78ce2098d01372ccb31157ab31ffd Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 24 May 2022 11:15:23 +0800 Subject: [PATCH 09/17] unify Traces and MergedTraces --- src/episodes.jl | 4 ++- src/traces.jl | 86 +++++++++++++++++++++--------------------------- test/episodes.jl | 5 +++ test/runtests.jl | 1 + 4 files changed, 46 insertions(+), 50 deletions(-) create mode 100644 test/episodes.jl diff --git a/src/episodes.jl b/src/episodes.jl index fef27ff..69965ab 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -20,7 +20,7 @@ Base.size(e::Episode) = size(e.traces) Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false)) -for f in (:push!, :pushfirst!, :append!, :prepend!) +for f in (:push!, :append!) @eval function Base.$f(t::Episode, x) if t.is_terminated[] throw(ArgumentError("The episode is already flagged as done!")) @@ -35,6 +35,8 @@ function Base.pop!(t::Episode) t.is_terminated[] = false end +Base.pushfirst!(t::Episode, x) = pushfirst!(t.traces, x) +Base.prepend!(t::Episode, x) = prepend!(t.traces, x) Base.popfirst!(t::Episode) = popfirst!(t.traces) function Base.empty!(t::Episode) diff --git a/src/traces.jl b/src/traces.jl index 47d4124..e28ec6d 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -52,38 +52,6 @@ Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names ##### -""" - Traces(;kw...) 
-""" -struct Traces{T,names,E} <: AbstractTraces{names,E} - traces::T - function Traces(; kw...) - data = map(x -> convert(AbstractTrace, x), values(kw)) - new{typeof(data),keys(data),Tuple{typeof(data).types...}}(data) - end -end - -Base.getindex(t::Traces, s::Symbol) = getindex(t.traces, s) -Base.getindex(t::Traces, i) = map(x -> getindex(x, i), t.traces) - -@forward Traces.traces Base.parent - -Base.size(t::Traces) = (mapreduce(length, min, t.traces),) - -for f in (:push!, :pushfirst!, :append!, :prepend!) - @eval function Base.$f(ts::Traces, xs::NamedTuple) - for (k, v) in pairs(xs) - $f(ts.traces[k], v) - end - end -end - -for f in (:pop!, :popfirst!, :empty!) - @eval Base.$f(ts::Traces) = map($f, ts.traces) -end - -##### - """ MultiplexTraces{names}(trace) @@ -144,57 +112,77 @@ for f in (:push!, :pushfirst!, :append!, :prepend!) end ##### - -struct MergedTraces{names,T,N,E} <: AbstractTraces{names,E} +struct Traces{names,T,N,E} <: AbstractTraces{names,E} traces::T inds::NamedTuple{names,NTuple{N,Int}} end -Base.getindex(ts::MergedTraces, s::Symbol) = ts.traces[ts.inds[s]][s] + +function Traces(; kw...) + data = map(x -> convert(AbstractTrace, x), values(kw)) + names = keys(data) + inds = NamedTuple(k => i for (i, k) in enumerate(names)) + Traces{names,typeof(data),length(names),typeof(values(data))}(data, inds) +end + + +function Base.getindex(ts::Traces, s::Symbol) + t = ts.traces[ts.inds[s]] + if t isa AbstractTrace + t + else + t[s] + end +end + +Base.getindex(t::Traces{names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names)) function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::AbstractTraces{k2,T2}) where {k1,k2,T1,T2} ks = (k1..., k2...) ts = (t1, t2) inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...) 
- MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::MergedTraces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2} +function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::Traces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2} ks = (k1..., k2...) ts = (t1, t2.traces...) - inds = merge(NamedTuple(k => 1 for k in k1), map(v => v + 1, t1.inds)) - MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) + inds = merge(NamedTuple(k => 1 for k in k1), map(v -> v + 1, t2.inds)) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.:(+)(t1::MergedTraces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T,N,T1,k2,T2} +function Base.:(+)(t1::Traces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T,N,T1,k2,T2} ks = (k1..., k2...) ts = (t1.traces..., t2) inds = merge(t1.inds, (; (k => length(ts) for k in k2)...)) - MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) end -function Base.:(+)(t1::MergedTraces{k1,T1,N1,E1}, t2::MergedTraces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} +function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} ks = (k1..., k2...) ts = (t1.traces..., t2.traces...) inds = merge(t1.inds, map(x -> x + length(t1.traces), t2.inds)) - MergedTraces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) + Traces{ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}}(ts, inds) end - -Base.size(t::MergedTraces) = (mapreduce(length, min, t.traces),) -Base.getindex(t::MergedTraces, I) = mapreduce(x -> getindex(x, I), merge, t.traces) +Base.size(t::Traces) = (mapreduce(length, min, t.traces),) for f in (:push!, :pushfirst!, :append!, :prepend!) 
- @eval function Base.$f(ts::MergedTraces, xs::NamedTuple) + @eval function Base.$f(ts::Traces, xs::NamedTuple) for (k, v) in pairs(xs) - $f(ts.traces[ts.inds[k]], (; k => v)) + t = ts.traces[ts.inds[k]] + if t isa AbstractTrace + $f(t, v) + else + $f(t, (; k => v)) + end end end end for f in (:pop!, :popfirst!, :empty!) - @eval function Base.$f(ts::MergedTraces) + @eval function Base.$f(ts::Traces) for t in ts.traces $f(t) end diff --git a/test/episodes.jl b/test/episodes.jl new file mode 100644 index 0000000..13afae0 --- /dev/null +++ b/test/episodes.jl @@ -0,0 +1,5 @@ +@testset "Episode" begin + e = Episode( + Traces() + ) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d1a2a21..ac22051 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Test @testset "Trajectories.jl" begin include("traces.jl") + include("episodes.jl") include("common.jl") include("trajectories.jl") end From 84463c0484d76e8222cf33b1f74d8b0b33f9c4ac Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 24 May 2022 11:31:15 +0800 Subject: [PATCH 10/17] unify Traces and MergedTraces --- src/traces.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/traces.jl b/src/traces.jl index e28ec6d..9cd3b56 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -8,10 +8,15 @@ abstract type AbstractTrace{E} <: AbstractVector{E} end Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x +Base.summary(io::IO, t::AbstractTrace) = print(io, "$(length(t))-element $(nameof(typeof(t)))") + +##### struct Trace{T,E} <: AbstractTrace{E} parent::T end +Base.summary(io::IO, t::Trace{T}) where {T} = print(io, "$(length(t))-element $(nameof(typeof(t))){$T}") + function Trace(x::T) where {T<:AbstractArray} E = eltype(x) N = ndims(x) - 1 @@ -41,7 +46,7 @@ abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end function Base.show(io::IO, ::MIME"text/plain", t::AbstractTraces{names,T}) where {names,T} s = 
nameof(typeof(t)) - println(io, "$s with $(length(names)) traces:") + println(io, "$s with $(length(names)) entries:") for n in names println(io, " :$n => $(summary(t[n]))") end From 95180bedf9b6c51055a851a0e611e72aec656b9f Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 24 May 2022 21:42:45 +0800 Subject: [PATCH 11/17] add more tests --- Project.toml | 1 + src/Trajectories.jl | 1 - src/episodes.jl | 94 -------------------------------- src/traces.jl | 127 +++++++++++++++++++++++++++++++++++++++++++- test/episodes.jl | 5 -- test/runtests.jl | 1 - test/traces.jl | 71 +++++++++++++++++++++++++ 7 files changed, 198 insertions(+), 102 deletions(-) delete mode 100644 src/episodes.jl delete mode 100644 test/episodes.jl diff --git a/Project.toml b/Project.toml index efa7a7b..921e065 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" [compat] diff --git a/src/Trajectories.jl b/src/Trajectories.jl index 10e2d2d..00cf79a 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,7 +1,6 @@ module Trajectories include("traces.jl") -include("episodes.jl") include("samplers.jl") include("controlers.jl") include("trajectory.jl") diff --git a/src/episodes.jl b/src/episodes.jl deleted file mode 100644 index 69965ab..0000000 --- a/src/episodes.jl +++ /dev/null @@ -1,94 +0,0 @@ -export Episode, Episodes - - -""" - Episode(traces) - -An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]` -to check/update whether the episode reaches a terminal or not. 
-""" -struct Episode{T,E} <: AbstractVector{E} - traces::T - is_terminated::Ref{Bool} -end - -Base.getindex(e::Episode, I) = getindex(e.traces, I) -Base.getindex(e::Episode) = getindex(e.is_terminated) -Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x) - -Base.size(e::Episode) = size(e.traces) - -Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false)) - -for f in (:push!, :append!) - @eval function Base.$f(t::Episode, x) - if t.is_terminated[] - throw(ArgumentError("The episode is already flagged as done!")) - else - $f(t.traces, x) - end - end -end - -function Base.pop!(t::Episode) - pop!(t.traces) - t.is_terminated[] = false -end - -Base.pushfirst!(t::Episode, x) = pushfirst!(t.traces, x) -Base.prepend!(t::Episode, x) = prepend!(t.traces, x) -Base.popfirst!(t::Episode) = popfirst!(t.traces) - -function Base.empty!(t::Episode) - empty!(t.traces) - t.is_terminated[] = false -end - -##### - -""" - Episodes(init) - -A container for multiple [`Episode`](@ref)s. `init` is a parameterness function which return an [`Episode`](@ref). 
-""" -struct Episodes <: AbstractVector{Episode} - init::Any - episodes::Vector{Episode} - inds::Vector{Tuple{Int,Int}} -end - -Base.size(e::Episodes) = size(e.inds) -Base.getindex(e::Episodes, I) = getindex(e.episodes, I) - -function Base.push!(e::Episodes, x::Episode) - push!(e.episodes, x) - for i in 1:length(x) - push!(e.inds, (length(e.episodes), i)) - end -end - -function Base.append!(e::Episodes, xs::AbstractVector{<:Episode}) - for x in xs - push!(e, x) - end -end - -function Base.push!(e::Episodes, x) - if isempty(e.episodes) || e.episodes[end][] - episode = e.init() - push!(episode, x) - push!(e.episodes, episode) - else - push!(e.episodes[end], x) - push!(e.inds, (length(e.episodes), length(e.episodes[end]))) - end -end - -function Base.append!(e::Episodes, x) - n_pre = length(e.episodes[end]) - append!(e.episodes[end], x) - n_post = length(e.episodes[end]) - for i in n_pre:n_post - push!(e.inds, (lengthe.episodes, i)) - end -end diff --git a/src/traces.jl b/src/traces.jl index 9cd3b56..62bfc4e 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,6 +1,7 @@ -export Trace, Traces, MultiplexTraces +export Trace, Traces, MultiplexTraces, Episode, Episodes import MacroTools: @forward +import StackViews: StackView ##### @@ -116,6 +117,130 @@ for f in (:push!, :pushfirst!, :append!, :prepend!) end end +##### + +""" + Episode(traces) + +An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]` +to check/update whether the episode reaches a terminal or not. +""" +struct Episode{T,names,E} <: AbstractTraces{names,E} + traces::T + is_terminated::Ref{Bool} +end + +Episode(t::AbstractTraces{names,T}) where {names,T} = Episode{typeof(t),names,T}(t, Ref(false)) + +@forward Episode.traces Base.getindex, Base.setindex!, Base.size + +Base.getindex(e::Episode) = getindex(e.is_terminated) +Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x) + +for f in (:push!, :append!) 
+ @eval function Base.$f(t::Episode, x) + if t.is_terminated[] + throw(ArgumentError("The episode is already flagged as done!")) + else + $f(t.traces, x) + end + end +end + +function Base.pop!(t::Episode) + pop!(t.traces) + t.is_terminated[] = false +end + +Base.pushfirst!(t::Episode, x) = pushfirst!(t.traces, x) +Base.prepend!(t::Episode, x) = prepend!(t.traces, x) +Base.popfirst!(t::Episode) = popfirst!(t.traces) + +function Base.empty!(t::Episode) + empty!(t.traces) + t.is_terminated[] = false +end + +##### + +""" + Episodes(init) + +A container for multiple [`Episode`](@ref)s. `init` is a parameterless function which returns an empty [`Episode`](@ref). +""" +struct Episodes{names,E} <: AbstractTraces{names,E} + init::Any + episodes::Vector{Episode} + inds::Vector{Tuple{Int,Int}} +end + +function Episodes(init) + x = init() + @assert x isa Episode + @assert length(x) == 0 + names, E = eltype(x).parameters + Episodes{names,E}(init, [x], Tuple{Int,Int}[]) +end + +Base.size(e::Episodes) = size(e.inds) + +Base.setindex!(e::Episodes, is_terminated::Bool) = setindex!(e.episodes[end], is_terminated) + +Base.getindex(e::Episodes) = getindex(e.episodes[end]) + +function Base.getindex(e::Episodes, I::Int) + i, j = e.inds[I] + e.episodes[i][j] +end + +function Base.getindex(e::Episodes{names}, I) where {names} + NamedTuple{names}( + StackView( + map(I) do i + x, y = e.inds[i] + e.episodes[x][n][y] + end + ) + for n in names + ) +end + +function Base.getindex(e::Episodes, I::Symbol) + @warn "The returned trace is a vector of partitions instead of a continuous view" maxlog = 1 + map(x -> x[I], e.episodes) +end + +function Base.push!(e::Episodes, x::Episode) + # !!! note we do not check whether the last Episode is terminated or not here + push!(e.episodes, x) + for i in 1:length(x) + push!(e.inds, (length(e.episodes), i)) + end +end + +function Base.append!(e::Episodes, xs::AbstractVector{<:Episode}) + # !!! 
note we do not check whether each Episode is terminated or not here + for x in xs + push!(e, x) + end +end + +function Base.push!(e::Episodes, x::NamedTuple) + if isempty(e.episodes) || e.episodes[end][] + episode = e.init() + push!(episode, x) + push!(e, episode) + else + n_pre = length(e.episodes[end]) + push!(e.episodes[end], x) + n_post = length(e.episodes[end]) + # this is to support partial inserting + if n_post - n_pre == 1 + push!(e.inds, (length(e.episodes), length(e.episodes[end]))) + end + end +end + ##### struct Traces{names,T,N,E} <: AbstractTraces{names,E} traces::T diff --git a/test/episodes.jl b/test/episodes.jl deleted file mode 100644 index 13afae0..0000000 --- a/test/episodes.jl +++ /dev/null @@ -1,5 +0,0 @@ -@testset "Episode" begin - e = Episode( - Traces() - ) -end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ac22051..d1a2a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using Test @testset "Trajectories.jl" begin include("traces.jl") - include("episodes.jl") include("common.jl") include("trajectories.jl") end diff --git a/test/traces.jl b/test/traces.jl index ec81293..82380b4 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -109,4 +109,75 @@ end t8_view = t8[2:3] t8_view.a[1] = 0 @test t8[:a][2] == 0 +end + +@testset "Episode" begin + t = Episode( + Traces( + state=Int[], + action=Float64[] + ) + ) + + @test length(t) == 0 + + push!(t, (state=1, action=1.0)) + @test length(t) == 1 + + append!(t, (state=[2, 3], action=[2.0, 3.0])) + @test length(t) == 3 + + @test t[:state] == [1, 2, 3] + @test t[end-1:end] == ( + state=[2, 3], + action=[2.0, 3.0] + ) + + t[] = true # seal + @test_throws ArgumentError push!(t, (state=4, action=4.0)) + + pop!(t) + @test length(t) == 2 + + push!(t, (state=4, action=4.0)) + @test length(t) == 3 + + t[] = true # seal + empty!(t) + + @test length(t) == 0 +end + +@testset "Episodes" begin + t = Episodes() do + Episode(Traces(state=Float64[], 
action=Int[])) + end + + @test length(t) == 0 + + push!(t, (state=1.0, action=1)) + + @test length(t) == 1 + @test t[1] == (state=1.0, action=1) + + t[] = true # seal + + push!(t, (state=2.0, action=2)) + @test length(t) == 2 + + @test t[end] == (state=2.0, action=2) + + # https://github.com/JuliaArrays/StackViews.jl/issues/3 + @test_broken t[1:2] == (state=[1.0, 2.0], action=[1, 2]) + + push!(t, (state=3.0, action=3)) + t[] = true # seal + + @test_broken size(t[:state]) == (3,) + + push!(t, Episode(Traces(state=[4.0, 5.0, 6.0], action=[4, 5, 6]))) + @test t[] == false + + t[] = true + @test length(t) == 6 end \ No newline at end of file From 1c72ba7fae19e0728a870e07fc4e865d0acc487d Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 12:51:08 +0800 Subject: [PATCH 12/17] add more tesets --- src/LastDimSlices.jl | 5 ----- src/Trajectories.jl | 2 ++ src/patch.jl | 3 +++ src/samplers.jl | 5 ----- test/runtests.jl | 1 + test/samplers.jl | 25 +++++++++++++++++++++++++ 6 files changed, 31 insertions(+), 10 deletions(-) delete mode 100644 src/LastDimSlices.jl create mode 100644 src/patch.jl create mode 100644 test/samplers.jl diff --git a/src/LastDimSlices.jl b/src/LastDimSlices.jl deleted file mode 100644 index 188597e..0000000 --- a/src/LastDimSlices.jl +++ /dev/null @@ -1,5 +0,0 @@ -export LastDimSlices - -using MacroTools: @forward - -# See also https://github.com/JuliaLang/julia/pull/32310 diff --git a/src/Trajectories.jl b/src/Trajectories.jl index 00cf79a..7e825cc 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,5 +1,7 @@ module Trajectories +include("patch.jl") + include("traces.jl") include("samplers.jl") include("controlers.jl") diff --git a/src/patch.jl b/src/patch.jl new file mode 100644 index 0000000..9b08b8f --- /dev/null +++ b/src/patch.jl @@ -0,0 +1,3 @@ +import MLUtils + +MLUtils.batch(x::AbstractArray{<:Number}) = x \ No newline at end of file diff --git a/src/samplers.jl b/src/samplers.jl index 3877da9..a7a492d 100644 --- 
a/src/samplers.jl +++ b/src/samplers.jl @@ -23,8 +23,3 @@ function sample(s::BatchSampler, t::AbstractTraces) inds = rand(s.rng, 1:length(t), s.batch_size) map(s.transformer, t[inds]) end - -# function sample(s::BatchSampler, e::Episodes) -# inds = rand(s.rng, 1:length(t), s.batch_size) -# [s.episodes[e.inds[i][1]][e.inds[i][2]] for i in inds] |> s.transformer -# end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d1a2a21..ca4776f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Test @testset "Trajectories.jl" begin include("traces.jl") include("common.jl") + include("samplers.jl") include("trajectories.jl") end diff --git a/test/samplers.jl b/test/samplers.jl new file mode 100644 index 0000000..824654e --- /dev/null +++ b/test/samplers.jl @@ -0,0 +1,25 @@ +@testset "BatchSampler" begin + sz = 32 + s = BatchSampler(sz) + t = Traces( + state=rand(3, 4, 5), + action=rand(1:4, 5), + ) + + b = Trajectories.sample(s, t) + + @test keys(b) == (:state, :action) + @test size(b.state) == (3, 4, sz) + @test size(b.action) == (sz,) + + e = Episodes() do + Episode(Traces(state=rand(2, 3, 0), action=rand(0))) + end + + push!(e, Episode(Traces(state=rand(2, 3, 2), action=rand(2)))) + push!(e, Episode(Traces(state=rand(2, 3, 3), action=rand(3)))) + + @test length(e) == 5 + @test size(e[2:4].state) == (2, 3, 3) + @test_broken size(e[2:4].action) == (3,) +end \ No newline at end of file From b1a5c22b4189c79913954fd778dda8011263d47c Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 13:10:00 +0800 Subject: [PATCH 13/17] add alias for common traces --- src/common/CircularArraySARTTraces.jl | 14 ++++++++++++-- src/common/CircularArraySLARTTraces.jl | 17 ++++++++++++++--- src/common/common.jl | 7 +++++++ test/common.jl | 17 ++++++++++++++++- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index 0d77547..f29093a 100644 --- 
a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -1,5 +1,15 @@ export CircularArraySARTTraces +const CircularArraySARTTraces = Traces{ + SSAART, + <:Tuple{ + <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, + <:Trace{<:CircularArrayBuffer}, + <:Trace{<:CircularArrayBuffer}, + } +} + function CircularArraySARTTraces(; capacity::Int, state=Int => (), @@ -12,8 +22,8 @@ function CircularArraySARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 66e3a2f..121168b 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -1,5 +1,16 @@ export CircularArraySLARTTraces +const CircularArraySLARTTraces = Traces{ + SSLLAART, + <:Tuple{ + <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, + <:Trace{<:CircularArrayBuffer}, + <:Trace{<:CircularArrayBuffer}, + } +} + function CircularArraySLARTTraces(; capacity::Int, state=Int => (), @@ -14,9 +25,9 @@ function CircularArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity 
+ 1)) + - MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + - MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), diff --git a/src/common/common.jl b/src/common/common.jl index 4377083..ef647dc 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -1,5 +1,12 @@ using CircularArrayBuffers +const SS = (:state, :next_state) +const LL = (:legal_actions_mask, :next_legal_actions_mask) +const AA = (:action, :next_action) +const RT = (:reward, :terminal) +const SSAART = (SS..., AA..., RT...) +const SSLLAART = (SS..., LL..., AA..., RT...) 
+ include("sum_tree.jl") include("CircularArraySARTTraces.jl") include("CircularArraySLARTTraces.jl") diff --git a/test/common.jl b/test/common.jl index 36eb8c9..091d587 100644 --- a/test/common.jl +++ b/test/common.jl @@ -2,11 +2,13 @@ t = CircularArraySARTTraces(; capacity=3, state=Float32 => (2, 3), - action=Float32 => (2), + action=Float32 => (2,), reward=Float32 => (), terminal=Bool => () ) + @test t isa CircularArraySARTTraces + push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2))) @test length(t) == 0 @@ -61,4 +63,17 @@ @test size(batch.action) == (2, 3) @test batch.reward == [2.0, 3.0, 4.0] @test batch.terminal == Bool[0, 0, 0] +end + +@testset "CircularArraySLARTTraces" begin + t = CircularArraySLARTTraces(; + capacity=3, + state=Float32 => (2, 3), + legal_actions_mask=Bool => (5,), + action=Int => (), + reward=Float32 => (), + terminal=Bool => () + ) + + @test t isa CircularArraySLARTTraces end \ No newline at end of file From 344b3b61ddc4dbb3931fae4e7703c7deef0fb7be Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 13:27:11 +0800 Subject: [PATCH 14/17] move tests of sumtree here --- README.md | 207 +++---------------------------------------------- test/common.jl | 26 +++++++ 2 files changed, 38 insertions(+), 195 deletions(-) diff --git a/README.md b/README.md index 1e6ba6f..cb15a05 100644 --- a/README.md +++ b/README.md @@ -56,209 +56,26 @@ julia> for batch in t (a = [4, 1, 2], b = Bool[1, 0, 1]) ``` -### `AbstractTrace` +**Traces** -`Trace` is the most commonly used `AbstractTrace`. It provides a sequential view on other containers. 
+- `Traces` +- `MultiplexTraces` +- `CircularArraySARTTraces` +- `Episode` +- `Episodes` -```julia -julia> t = Trace([1,2,3]) -3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: - 1 - 2 - 3 -julia> push!(t, 4) -4-element Vector{Int64}: - 1 - 2 - 3 - 4 - -julia> append!(t, 5:6) -6-element Vector{Int64}: - 1 - 2 - 3 - 4 - 5 - 6 - -julia> pop!(t) -6 - -julia> popfirst!(t) -1 - -julia> t -4-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: - 2 - 3 - 4 - 5 - -julia> empty!(t) -Int64[] - -julia> t -0-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} -``` - -In most cases, it's just the same with a `Vector`. - -When an `AbstractArray` with higher dimension provided, it is **slice**d along the last dimension to provide a sequential view. - -```julia -julia> t = Trace(rand(2,3)) -3-element Trace{Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}: - [0.276012181224494, 0.6621365818458671] - [0.9937726056924112, 0.3308302850028162] - [0.9856543000075456, 0.6123660950650406] - -julia> t[1] -2-element view(::Matrix{Float64}, :, 1) with eltype Float64: - 0.276012181224494 - 0.6621365818458671 - -julia> t[1] = [0., 1.] -2-element Vector{Float64}: - 0.0 - 1.0 - -julia> t -3-element Trace{Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}: - [0.0, 1.0] - [0.9937726056924112, 0.3308302850028162] - [0.9856543000075456, 0.6123660950650406] - -julia> t[[2,3,1]] -2×3 view(::Matrix{Float64}, :, [2, 3, 1]) with eltype Float64: - 0.993773 0.985654 0.0 - 0.33083 0.612366 1.0 -``` - -**Note** that when indexing a `Trace`, a **view** is returned. As you can see above, the data is modified in-place. - -### `AbstractTraces` - -`Traces` is one of the common `AbstractTraces`. It is similar to a `NamedTuple` of several traces. 
- -```julia -julia> t = Traces(; - a=[1, 2], - b=Bool[0, 1] - ) # note that `a` and `b` are converted into `Trace` implicitly -Traces with 2 traces: - :a => 2-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} - :b => 2-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}} - - -julia> push!(t, (a=3, b=false)) +**Samplers** -julia> t -Traces with 2 traces: - :a => 3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} - :b => 3-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}} +- `BatchSampler` +**Controllers** -julia> t[:a] -3-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: - 1 - 2 - 3 +- `InsertSampleRatioController` +- `AsyncInsertSampleRatioController` -julia> t[:b] -3-element Trace{Vector{Bool}, SubArray{Bool, 0, Vector{Bool}, Tuple{Int64}, true}}: - false - true - false -julia> t[1] -(a = 1, b = false) +Please refer to the tests for common usage. (TODO: generate docs and add links to above data structures) -julia> t[1:3] -(a = [1, 2, 3], b = Bool[0, 1, 0]) -``` - -Another commonly used traces is `MultiplexTraces`. In reinforcement learning, *states* and *next-states* share most data except for the first and last element. 
- -```julia -julia> t = MultiplexTraces{(:state, :next_state)}([1,2,3]); - -julia> t[:state] -2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 1 - 2 - -julia> t[:next_state] -2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 2 - 3 - -julia> push!(t, (;state=4)) -4-element Vector{Int64}: - 1 - 2 - 3 - 4 - -julia> t[:state] -3-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 1 - 2 - 3 - -julia> t[:next_state] -3-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 2 - 3 - 4 - -julia> length(t) -3 -``` - -Note that different kinds of `AbstractTraces` can be combined to form a `MergedTraces`. 
- -``` -ulia> t1 = Traces(a=Int[]) - t2 = MultiplexTraces{(:b, :c)}(Int[]) - t3 = t1 + t2 -MergedTraces with 3 traces: - :a => 0-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}} - :b => 0-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}} - :c => 0-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}} - - -julia> push!(t3, (a=1,b=2,c=3)) - -julia> t3[:a] -1-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: - 1 - -julia> t3[:b] -1-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 2 - -julia> t3[:c] -1-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 3 - -julia> push!(t3, (a=-1, b=-2)) - -julia> t3[:a] -2-element Trace{Vector{Int64}, SubArray{Int64, 0, Vector{Int64}, Tuple{Int64}, true}}: - 1 - -1 - -julia> t3[:b] -2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 2 - 3 - -julia> t3[:c] -2-element Trace{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, SubArray{Int64, 0, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Tuple{Int64}, true}}: - 3 - -2 -``` ## Acknowledgement This async version is mainly inspired by [deepmind/reverb](https://github.com/deepmind/reverb). 
diff --git a/test/common.jl b/test/common.jl index 091d587..444c3fc 100644 --- a/test/common.jl +++ b/test/common.jl @@ -1,3 +1,29 @@ +@testset "sum_tree" begin + t = SumTree(8) + + for i in 1:4 + push!(t, i) + end + + @test length(t) == 4 + @test size(t) == (4,) + + for i in 5:16 + push!(t, i) + end + + @test length(t) == 8 + @test size(t) == (8,) + @test t == 9:16 + + t[:] .= 1 + @test t == ones(8) + @test all([get(t, v)[1] == i for (i, v) in enumerate(0.5:1.0:8)]) + + empty!(t) + @test length(t) == 0 +end + @testset "CircularArraySARTTraces" begin t = CircularArraySARTTraces(; capacity=3, From ab664e96e9c7425e1513455cd86b9c9b2136e3fb Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 18:41:51 +0800 Subject: [PATCH 15/17] resolve comment: add parameter of inner Episode in Episodes --- src/traces.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/traces.jl b/src/traces.jl index 62bfc4e..fd9f20a 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -168,18 +168,19 @@ end A container for multiple [`Episode`](@ref)s. `init` is a parameterness function which return an empty [`Episode`](@ref). 
""" -struct Episodes{names,E} <: AbstractTraces{names,E} +struct Episodes{names,E,T} <: AbstractTraces{names,E} init::Any - episodes::Vector{Episode} + episodes::Vector{T} inds::Vector{Tuple{Int,Int}} end function Episodes(init) x = init() + T = typeof(x) @assert x isa Episode @assert length(x) == 0 names, E = eltype(x).parameters - Episodes{names,E}(init, [x], Tuple{Int,Int}[]) + Episodes{names,E,T}(init, [x], Tuple{Int,Int}[]) end Base.size(e::Episodes) = size(e.inds) From 87b9cf39c88e409479035f746686ac0423f2e296 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 18:42:43 +0800 Subject: [PATCH 16/17] Update README.md Co-authored-by: Henri Dehaybe <47037088+HenriDeh@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index cb15a05..4955d8c 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ julia> for batch in t **Controllers** - `InsertSampleRatioController` +- `InsertSampleController` - `AsyncInsertSampleRatioController` From 162d4b9d1be0f8be9331f2a28d72d0484b5eda99 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 25 May 2022 18:43:00 +0800 Subject: [PATCH 17/17] Update README.md Co-authored-by: Henri Dehaybe <47037088+HenriDeh@users.noreply.github.com> --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 4955d8c..e72d5dc 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,8 @@ julia> for batch in t **Samplers** - `BatchSampler` +- `MetaSampler` +- `MultiBatchSampler` **Controllers**