diff --git a/examples/transitions/src/main.rs b/examples/transitions/src/main.rs index 4396cf75d..3e3e895e7 100644 --- a/examples/transitions/src/main.rs +++ b/examples/transitions/src/main.rs @@ -37,7 +37,7 @@ async fn Child(cx: Scope<'_>, tab: Tab) -> View { fn App(cx: Scope) -> View { let tab = create_signal(cx, Tab::One); let transition = use_transition(cx); - let update = move |x| transition.start(move || tab.set(x)); + let update = move |x| transition.start(move || tab.set(x), || ()); view! { cx, div { diff --git a/packages/sycamore/src/suspense.rs b/packages/sycamore/src/suspense.rs index 31d035608..05f83850d 100644 --- a/packages/sycamore/src/suspense.rs +++ b/packages/sycamore/src/suspense.rs @@ -3,7 +3,7 @@ //! The [`Suspense`] component is used to "suspend" execution and wait until async tasks are //! finished before rendering. -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use futures::channel::oneshot; use futures::Future; @@ -56,33 +56,16 @@ pub struct SuspenseProps<'a, G: GenericNode> { /// ``` #[component] pub fn Suspense<'a, G: GenericNode>(cx: Scope<'a>, props: SuspenseProps<'a, G>) -> View { - let state = use_context_or_else(cx, SuspenseState::default); - // Get the outer suspense state. - let outer_count = state.async_counts.borrow().last().cloned(); - // Push a new suspense state. - let count = create_rc_signal(0); - state.async_counts.borrow_mut().push(count.clone()); - let ready = create_selector(cx, move || *count.get() == 0); - - let v = props.children.call(cx); - // Pop the suspense state. - state.async_counts.borrow_mut().pop().unwrap(); - - if let Some(outer_state) = outer_count { - outer_state.set(*outer_state.get() + 1); - // We keep track whether outer_state has already been decremented to prevent it from being - // decremented twice. - let completed = create_ref(cx, Cell::new(false)); - create_effect(cx, move || { - if !completed.get() && *ready.get() { - outer_state.set(*outer_state.get() - 1); - completed.set(true); - } - }); - } + let v = create_signal(cx, None); + // If the Suspense is nested under another Suspense, we want the other Suspense to await this + // one as well. + suspense_scope(cx, async move { + let res = await_suspense(cx, async move { props.children.call(cx) }).await; + v.set(Some(res)); + }); view! { cx, - (if *ready.get() { v.clone() } else { props.fallback.clone() }) + (if let Some(v) = v.get().as_ref() { v.clone() } else { props.fallback.clone() }) } } @@ -106,6 +89,8 @@ pub fn suspense_scope<'a>(cx: Scope<'a>, f: impl Future + 'a) { } /// Waits until all suspense tasks created within the scope are finished. +/// If called inside an outer suspense scope, this will also make the outer suspense scope suspend +/// until this resolves. pub async fn await_suspense(cx: Scope<'_>, f: impl Future) -> U { let state = use_context_or_else(cx, SuspenseState::default); // Get the outer suspense state. @@ -126,7 +111,7 @@ pub async fn await_suspense(cx: Scope<'_>, f: impl Future) -> U { state.async_counts.borrow_mut().pop().unwrap(); let (sender, receiver) = oneshot::channel(); - let sender = create_ref(cx, RefCell::new(Some(sender))); + let mut sender = Some(sender); create_effect(cx, move || { if *ready.get() { @@ -136,7 +121,7 @@ pub async fn await_suspense(cx: Scope<'_>, f: impl Future) -> U { } }); let _ = receiver.await; - if let Some(outer_count) = &outer_count { + if let Some(outer_count) = outer_count { outer_count.set(*outer_count.get() - 1); } ret @@ -158,11 +143,12 @@ impl<'a> TransitionHandle<'a> { } /// Start a transition. - pub fn start(self, f: impl Fn() + 'a) { + pub fn start(self, f: impl FnOnce() + 'a, done: impl FnOnce() + 'a) { spawn_local_scoped(self.cx, async move { self.is_pending.set(true); await_suspense(self.cx, async move { f() }).await; self.is_pending.set(false); + done(); }); } } @@ -207,7 +193,9 @@ mod tests { #[tokio::test] async fn transition() { provide_executor_scope(async { - create_scope_immediate(|cx| { + let (sender, receiver) = oneshot::channel(); + let mut sender = Some(sender); + let disposer = create_scope(|cx| { let trigger = create_signal(cx, ()); let transition = use_transition(cx); let _: View = view! { cx, @@ -217,12 +205,20 @@ mod tests { trigger.track(); assert!(try_use_context::(cx).is_some()); }); - View::empty() + view! { cx, } }) } }; - transition.start(|| trigger.set(())); + let done = create_signal(cx, false); + transition.start(|| trigger.set(()), || done.set(true)); + create_effect(cx, move || { + if *done.get() { + sender.take().unwrap().send(()).unwrap(); + } + }) }); + receiver.await.unwrap(); + unsafe { disposer.dispose() }; }) .await; }