Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional callbacks to the collectLatest and combineLatestUpdates #51

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions src/observable/collected.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ export collectLatest
import Base: show

"""
collectLatest(sources::S, mappingFn::F = copy) where { S, F }
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy)
collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing)
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing)

Collects values from multible Observables and emits it in one single array every time each inner Observable has a new value.
Reemits errors from inner observables. Completes when all inner observables completes.

# Arguments
- `sources`: input sources
- `mappingFn`: optional mappingFn applied to an array of emited values, `copy` by default, should return a Vector
- `callbackFn`: optional callback function, which is called right after `mappingFn` has been evaluated, accepts the state of the inner actor and the computed value, `nothing` by default

Note: `collectLatest` completes immediately if `sources` are empty.

Expand All @@ -37,17 +38,17 @@ subscribe!(collected, logger())

See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`combineLatest`](@ref)
"""
function collectLatest(sources::S, mappingFn::F = copy) where { S, F }
function collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { S, F, C }
T = union_type(sources)
R = similar_typeof(sources, T)
return CollectLatestObservable{T, S, R, F}(sources, mappingFn)
return CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn)
end

collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy) where { T, R, S, F } = CollectLatestObservable{T, S, R, F}(sources, mappingFn)
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { T, R, S, F, C } = CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn)

##

struct CollectLatestObservableWrapper{L, A, S, B, T, F}
struct CollectLatestObservableWrapper{L, A, S, B, T, F, C}
actor :: A
storage :: S

Expand All @@ -56,25 +57,30 @@ struct CollectLatestObservableWrapper{L, A, S, B, T, F}
ustatus :: B # Updates status
subscriptions :: T
mappingFn :: F
callbackFn :: C

CollectLatestObservableWrapper{L, A, S, B, T, F}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F) where {L, A, S, B, T, F} = begin
return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn)
CollectLatestObservableWrapper{L, A, S, B, T, F, C}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F, callbackFn::C) where {L, A, S, B, T, F, C} = begin
return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn)
end
end

function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F) where { L, A, S, F }
function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F, callbackFn::C) where { L, A, S, F, C }
nsize = size(storage)
cstatus = falses(nsize)
vstatus = falses(nsize)
ustatus = falses(nsize)
subscriptions = fill!(similar(storage, Teardown), voidTeardown)
return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn)
return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F, C}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn)
end

cstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.cstatus[index]
vstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.vstatus[index]
ustatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.ustatus[index]

fill_cstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.cstatus, value)
fill_vstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.vstatus, value)
fill_ustatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.ustatus, value)

dispose(wrapper::CollectLatestObservableWrapper) = begin fill!(wrapper.cstatus, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end

struct CollectLatestObservableInnerActor{L, I <: CartesianIndex, W} <: Actor{L}
Expand All @@ -94,7 +100,11 @@ function next_received!(wrapper::CollectLatestObservableWrapper, data, index::Ca
@inbounds wrapper.ustatus[index] = true
if all(wrapper.vstatus) && !all(wrapper.cstatus)
unsafe_copyto!(wrapper.vstatus, 1, wrapper.cstatus, 1, length(wrapper.vstatus))
next!(wrapper.actor, wrapper.mappingFn(wrapper.storage))
value = wrapper.mappingFn(wrapper.storage)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better to parameterize CollectLatestObservableWrapper here? I think it will compile away the call to isnothing if {C} is nothing already

Copy link
Member Author

@bvdmitri bvdmitri Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is parametrized with C already

next!(wrapper.actor, value)
if !isnothing(wrapper.callbackFn)
wrapper.callbackFn(wrapper, value)
end
end
end

Expand All @@ -120,15 +130,16 @@ end

##

@subscribable struct CollectLatestObservable{T, S, R, F} <: Subscribable{R}
sources :: S
mappingFn :: F
@subscribable struct CollectLatestObservable{T, S, R, F, C} <: Subscribable{R}
sources :: S
mappingFn :: F
callbackFn :: C
end

function on_subscribe!(observable::CollectLatestObservable{L}, actor::A) where { L, A }
sources = observable.sources
storage = similar(sources, L)
wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn)
wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn, observable.callbackFn)
W = typeof(wrapper)

if length(sources) !== 0
Expand Down
36 changes: 24 additions & 12 deletions src/observable/combined_updates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`PushEach`](@ref), [`Pu
"""
function combineLatestUpdates end

combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input")
combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy)
combineLatestUpdates(sources::S, strategy::G = PushEach()) where { S <: Tuple, G } = CombineLatestUpdatesObservable{S, G}(sources, strategy)
combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input")
combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy)
combineLatestUpdates(sources::S, strategy::G = PushEach(), ::Type{R} = S, mappingFn::F = identity, callbackFn::C = nothing) where { S <: Tuple, R, G, F, C } = CombineLatestUpdatesObservable{R, S, G, F, C}(sources, strategy, mappingFn, callbackFn)

##

Expand All @@ -37,32 +37,42 @@ on_complete!(actor::CombineLatestUpdatesInnerActor{L, W}) where { L, W } =

##

struct CombineLatestUpdatesActorWrapper{S, A, G, U}
struct CombineLatestUpdatesActorWrapper{S, A, G, U, F, C}
sources :: S
actor :: A
nsize :: Int
strategy :: G # Push update strategy
updates :: U # Updates
subscriptions :: Vector{Teardown}
mappingFn :: F
callbackFn :: C
end

function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G) where { S, A, G }
function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G, mappingFn::F, callbackFn::C) where { S, A, G, F, C }
updates = getustorage(S)
nsize = length(sources)
subscriptions = fill!(Vector{Teardown}(undef, nsize), voidTeardown)
return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions)
return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions, mappingFn, callbackFn)
end

push_update!(wrapper::CombineLatestUpdatesActorWrapper) = push_update!(wrapper.nsize, wrapper.updates, wrapper.strategy)

dispose(wrapper::CombineLatestUpdatesActorWrapper) = begin fill_cstatus!(wrapper.updates, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end

fill_cstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_cstatus!(wrapper.updates, value)
fill_vstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_vstatus!(wrapper.updates, value)
fill_ustatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_ustatus!(wrapper.updates, value)

function next_received!(wrapper::CombineLatestUpdatesActorWrapper, data, index::Int)
vstatus!(wrapper.updates, index, true)
ustatus!(wrapper.updates, index, true)
if all_vstatus(wrapper.updates) && !all_cstatus(wrapper.updates)
push_update!(wrapper)
next!(wrapper.actor, wrapper.sources)
value = wrapper.mappingFn(wrapper.sources)
next!(wrapper.actor, value)
if !isnothing(wrapper.callbackFn)
wrapper.callbackFn(wrapper, value)
end
end
end

Expand All @@ -88,15 +98,17 @@ end

##

@subscribable struct CombineLatestUpdatesObservable{S, G} <: Subscribable{S}
sources :: S
strategy :: G
@subscribable struct CombineLatestUpdatesObservable{R, S, G, F, C} <: Subscribable{R}
sources :: S
strategy :: G
mappingFn :: F
callbackFn :: C
end

getrecent(observable::CombineLatestUpdatesObservable) = getrecent(observable.sources)

function on_subscribe!(observable::CombineLatestUpdatesObservable{S, G}, actor::A) where { S, G, A }
wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy)
function on_subscribe!(observable::CombineLatestUpdatesObservable, actor)
wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy, observable.mappingFn, observable.callbackFn)

__combine_latest_updates_unrolled_fill_subscriptions!(observable.sources, wrapper)

Expand Down
46 changes: 46 additions & 0 deletions test/observable/test_observable_collect_latest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,52 @@ include("../test_helpers.jl")
unsubscribe!(subscription)
end

@testset begin
source1 = Subject(Int)
source2 = Subject(Int)

callbackCalled = []
callbackFn = (wrapper, value) -> begin
# We reset the state of the `vstatus`
if isequal(value, "2")
Rocket.fill_vstatus!(wrapper, true)
push!(callbackCalled, true)
else
push!(callbackCalled, false)
end
end

combined = collectLatest(Int, String, [ source1, source2 ], (values) -> string(sum(values)), callbackFn)
values = []
subscription = subscribe!(combined, (value) -> push!(values, value))

@test values == []
@test callbackCalled == []
next!(source1, 0)
@test values == []
@test callbackCalled == []
next!(source2, 0)
@test values == ["0"]
@test callbackCalled == [false]

next!(source1, 1)
@test values == ["0"]
@test callbackCalled == [false]
next!(source2, 1)
@test values == ["0", "2"]
@test callbackCalled == [false, true]

next!(source1, 2)
@test values == ["0", "2", "3"] # this is hapenning because the callback should have been called
@test callbackCalled == [false, true, false]
next!(source1, 2)
@test values == ["0", "2", "3"]
@test callbackCalled == [false, true, false]
next!(source2, 2)
@test values == ["0", "2", "3", "4"]
@test callbackCalled == [false, true, false, false]
end

end

end
46 changes: 46 additions & 0 deletions test/observable/test_observable_combine_updates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,52 @@ include("../test_helpers.jl")
unsubscribe!(subscription)
end

@testset begin
source1 = RecentSubject(Int)
source2 = RecentSubject(Int)

callbackCalled = []
callbackFn = (wrapper, value) -> begin
# We reset the state of the `vstatus`
if isequal(value, "2")
Rocket.fill_vstatus!(wrapper, true)
push!(callbackCalled, true)
else
push!(callbackCalled, false)
end
end

combined = combineLatestUpdates((source1, source2), PushNew(), String, (sources) -> string(sum(Rocket.getrecent.(sources))), callbackFn)
values = []
subscription = subscribe!(combined, (value) -> push!(values, value))

@test values == []
@test callbackCalled == []
next!(source1, 0)
@test values == []
@test callbackCalled == []
next!(source2, 0)
@test values == ["0"]
@test callbackCalled == [false]

next!(source1, 1)
@test values == ["0"]
@test callbackCalled == [false]
next!(source2, 1)
@test values == ["0", "2"]
@test callbackCalled == [false, true]

next!(source1, 2)
@test values == ["0", "2", "3"] # this is hapenning because the callback should have been called
@test callbackCalled == [false, true, false]
next!(source1, 2)
@test values == ["0", "2", "3"]
@test callbackCalled == [false, true, false]
next!(source2, 2)
@test values == ["0", "2", "3", "4"]
@test callbackCalled == [false, true, false, false]
end

end

end
Loading