From 336af9cd9745030647d96ffdcea787bbc84f7d3e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Wed, 10 Apr 2024 16:10:14 +0200 Subject: [PATCH] Add optional callbacks to the collectLatest and combineLatestUpdates functions --- src/observable/collected.jl | 41 +++++++++++------ src/observable/combined_updates.jl | 36 ++++++++++----- .../test_observable_collect_latest.jl | 46 +++++++++++++++++++ .../test_observable_combine_updates.jl | 46 +++++++++++++++++++ 4 files changed, 142 insertions(+), 27 deletions(-) diff --git a/src/observable/collected.jl b/src/observable/collected.jl index 4925fb14f..9dc017391 100644 --- a/src/observable/collected.jl +++ b/src/observable/collected.jl @@ -3,8 +3,8 @@ export collectLatest import Base: show """ - collectLatest(sources::S, mappingFn::F = copy) where { S, F } - collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy) + collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing) + collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing) Collects values from multible Observables and emits it in one single array every time each inner Observable has a new value. Reemits errors from inner observables. Completes when all inner observables completes. @@ -12,6 +12,7 @@ Reemits errors from inner observables. Completes when all inner observables comp # Arguments - `sources`: input sources - `mappingFn`: optional mappingFn applied to an array of emited values, `copy` by default, should return a Vector +- `callbackFn`: optional callback function, which is called right after `mappingFn` has been evaluated, accepts the state of the inner actor and the computed value, `nothing` by default Note: `collectLatest` completes immediately if `sources` are empty. @@ -37,17 +38,17 @@ subscribe!(collected, logger()) See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`combineLatest`](@ref) """ -function collectLatest(sources::S, mappingFn::F = copy) where { S, F } +function collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { S, F, C } T = union_type(sources) R = similar_typeof(sources, T) - return CollectLatestObservable{T, S, R, F}(sources, mappingFn) + return CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn) end -collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy) where { T, R, S, F } = CollectLatestObservable{T, S, R, F}(sources, mappingFn) +collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { T, R, S, F, C } = CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn) ## -struct CollectLatestObservableWrapper{L, A, S, B, T, F} +struct CollectLatestObservableWrapper{L, A, S, B, T, F, C} actor :: A storage :: S @@ -56,25 +57,30 @@ struct CollectLatestObservableWrapper{L, A, S, B, T, F} ustatus :: B # Updates status subscriptions :: T mappingFn :: F + callbackFn :: C - CollectLatestObservableWrapper{L, A, S, B, T, F}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F) where {L, A, S, B, T, F} = begin - return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn) + CollectLatestObservableWrapper{L, A, S, B, T, F, C}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F, callbackFn::C) where {L, A, S, B, T, F, C} = begin + return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn) end end -function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F) where { L, A, S, F } +function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F, callbackFn::C) where { L, A, S, F, C } nsize = size(storage) cstatus = falses(nsize) vstatus = falses(nsize) ustatus = falses(nsize) subscriptions = fill!(similar(storage, Teardown), voidTeardown) - return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn) + return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F, C}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn) end cstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.cstatus[index] vstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.vstatus[index] ustatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.ustatus[index] +fill_cstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.cstatus, value) +fill_vstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.vstatus, value) +fill_ustatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.ustatus, value) + dispose(wrapper::CollectLatestObservableWrapper) = begin fill!(wrapper.cstatus, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end struct CollectLatestObservableInnerActor{L, I <: CartesianIndex, W} <: Actor{L} @@ -94,7 +100,11 @@ function next_received!(wrapper::CollectLatestObservableWrapper, data, index::Ca @inbounds wrapper.ustatus[index] = true if all(wrapper.vstatus) && !all(wrapper.cstatus) unsafe_copyto!(wrapper.vstatus, 1, wrapper.cstatus, 1, length(wrapper.vstatus)) - next!(wrapper.actor, wrapper.mappingFn(wrapper.storage)) + value = wrapper.mappingFn(wrapper.storage) + next!(wrapper.actor, value) + if !isnothing(wrapper.callbackFn) + wrapper.callbackFn(wrapper, value) + end end end @@ -120,15 +130,16 @@ end ## -@subscribable struct CollectLatestObservable{T, S, R, F} <: Subscribable{R} - sources :: S - mappingFn :: F +@subscribable struct CollectLatestObservable{T, S, R, F, C} <: Subscribable{R} + sources :: S + mappingFn :: F + callbackFn :: C end function on_subscribe!(observable::CollectLatestObservable{L}, actor::A) where { L, A } sources = observable.sources storage = similar(sources, L) - wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn) + wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn, observable.callbackFn) W = typeof(wrapper) if length(sources) !== 0 diff --git a/src/observable/combined_updates.jl b/src/observable/combined_updates.jl index 1e97d8de6..04e97ab02 100644 --- a/src/observable/combined_updates.jl +++ b/src/observable/combined_updates.jl @@ -18,9 +18,9 @@ See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`PushEach`](@ref), [`Pu """ function combineLatestUpdates end -combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input") -combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy) -combineLatestUpdates(sources::S, strategy::G = PushEach()) where { S <: Tuple, G } = CombineLatestUpdatesObservable{S, G}(sources, strategy) +combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input") +combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy) +combineLatestUpdates(sources::S, strategy::G = PushEach(), ::Type{R} = S, mappingFn::F = identity, callbackFn::C = nothing) where { S <: Tuple, R, G, F, C } = CombineLatestUpdatesObservable{R, S, G, F, C}(sources, strategy, mappingFn, callbackFn) ## @@ -37,32 +37,42 @@ on_complete!(actor::CombineLatestUpdatesInnerActor{L, W}) where { L, W } = ## -struct CombineLatestUpdatesActorWrapper{S, A, G, U} +struct CombineLatestUpdatesActorWrapper{S, A, G, U, F, C} sources :: S actor :: A nsize :: Int strategy :: G # Push update strategy updates :: U # Updates subscriptions :: Vector{Teardown} + mappingFn :: F + callbackFn :: C end -function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G) where { S, A, G } +function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G, mappingFn::F, callbackFn::C) where { S, A, G, F, C } updates = getustorage(S) nsize = length(sources) subscriptions = fill!(Vector{Teardown}(undef, nsize), voidTeardown) - return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions) + return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions, mappingFn, callbackFn) end push_update!(wrapper::CombineLatestUpdatesActorWrapper) = push_update!(wrapper.nsize, wrapper.updates, wrapper.strategy) dispose(wrapper::CombineLatestUpdatesActorWrapper) = begin fill_cstatus!(wrapper.updates, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end +fill_cstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_cstatus!(wrapper.updates, value) +fill_vstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_vstatus!(wrapper.updates, value) +fill_ustatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_ustatus!(wrapper.updates, value) + function next_received!(wrapper::CombineLatestUpdatesActorWrapper, data, index::Int) vstatus!(wrapper.updates, index, true) ustatus!(wrapper.updates, index, true) if all_vstatus(wrapper.updates) && !all_cstatus(wrapper.updates) push_update!(wrapper) - next!(wrapper.actor, wrapper.sources) + value = wrapper.mappingFn(wrapper.sources) + next!(wrapper.actor, value) + if !isnothing(wrapper.callbackFn) + wrapper.callbackFn(wrapper, value) + end end end @@ -88,15 +98,17 @@ end ## -@subscribable struct CombineLatestUpdatesObservable{S, G} <: Subscribable{S} - sources :: S - strategy :: G +@subscribable struct CombineLatestUpdatesObservable{R, S, G, F, C} <: Subscribable{R} + sources :: S + strategy :: G + mappingFn :: F + callbackFn :: C end getrecent(observable::CombineLatestUpdatesObservable) = getrecent(observable.sources) -function on_subscribe!(observable::CombineLatestUpdatesObservable{S, G}, actor::A) where { S, G, A } - wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy) +function on_subscribe!(observable::CombineLatestUpdatesObservable, actor) + wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy, observable.mappingFn, observable.callbackFn) __combine_latest_updates_unrolled_fill_subscriptions!(observable.sources, wrapper) diff --git a/test/observable/test_observable_collect_latest.jl b/test/observable/test_observable_collect_latest.jl index ac6a312ba..ed42ead69 100644 --- a/test/observable/test_observable_collect_latest.jl +++ b/test/observable/test_observable_collect_latest.jl @@ -159,6 +159,52 @@ include("../test_helpers.jl") unsubscribe!(subscription) end + @testset begin + source1 = Subject(Int) + source2 = Subject(Int) + + callbackCalled = [] + callbackFn = (wrapper, value) -> begin + # We reset the state of the `vstatus` + if isequal(value, "2") + Rocket.fill_vstatus!(wrapper, true) + push!(callbackCalled, true) + else + push!(callbackCalled, false) + end + end + + combined = collectLatest(Int, String, [ source1, source2 ], (values) -> string(sum(values)), callbackFn) + values = [] + subscription = subscribe!(combined, (value) -> push!(values, value)) + + @test values == [] + @test callbackCalled == [] + next!(source1, 0) + @test values == [] + @test callbackCalled == [] + next!(source2, 0) + @test values == ["0"] + @test callbackCalled == [false] + + next!(source1, 1) + @test values == ["0"] + @test callbackCalled == [false] + next!(source2, 1) + @test values == ["0", "2"] + @test callbackCalled == [false, true] + + next!(source1, 2) + @test values == ["0", "2", "3"] # this is hapenning because the callback should have been called + @test callbackCalled == [false, true, false] + next!(source1, 2) + @test values == ["0", "2", "3"] + @test callbackCalled == [false, true, false] + next!(source2, 2) + @test values == ["0", "2", "3", "4"] + @test callbackCalled == [false, true, false, false] + end + end end diff --git a/test/observable/test_observable_combine_updates.jl b/test/observable/test_observable_combine_updates.jl index f4e008904..ef2ab462d 100644 --- a/test/observable/test_observable_combine_updates.jl +++ b/test/observable/test_observable_combine_updates.jl @@ -346,6 +346,52 @@ include("../test_helpers.jl") unsubscribe!(subscription) end + @testset begin + source1 = RecentSubject(Int) + source2 = RecentSubject(Int) + + callbackCalled = [] + callbackFn = (wrapper, value) -> begin + # We reset the state of the `vstatus` + if isequal(value, "2") + Rocket.fill_vstatus!(wrapper, true) + push!(callbackCalled, true) + else + push!(callbackCalled, false) + end + end + + combined = combineLatestUpdates((source1, source2), PushNew(), String, (sources) -> string(sum(Rocket.getrecent.(sources))), callbackFn) + values = [] + subscription = subscribe!(combined, (value) -> push!(values, value)) + + @test values == [] + @test callbackCalled == [] + next!(source1, 0) + @test values == [] + @test callbackCalled == [] + next!(source2, 0) + @test values == ["0"] + @test callbackCalled == [false] + + next!(source1, 1) + @test values == ["0"] + @test callbackCalled == [false] + next!(source2, 1) + @test values == ["0", "2"] + @test callbackCalled == [false, true] + + next!(source1, 2) + @test values == ["0", "2", "3"] # this is hapenning because the callback should have been called + @test callbackCalled == [false, true, false] + next!(source1, 2) + @test values == ["0", "2", "3"] + @test callbackCalled == [false, true, false] + next!(source2, 2) + @test values == ["0", "2", "3", "4"] + @test callbackCalled == [false, true, false, false] + end + end end