From ed8e4a21cfe08d6f29684c532bcc7f263ce5c6c1 Mon Sep 17 00:00:00 2001
From: Kestrer <kestrer.dev@gmail.com>
Date: Mon, 8 Mar 2021 20:12:39 +0000
Subject: [PATCH] Replace functions with Signal type (#20)

* Replace functions with Signal type

* Use DOS line endings

* Add Signal::handle

* Pass handler as a callback to the dependencies
---
 examples/components/src/main.rs |  11 +-
 examples/counter/src/main.rs    |  15 +-
 examples/hello/src/main.rs      |  21 +-
 maple-core/src/lib.rs           |   5 +-
 maple-core/src/reactive.rs      | 453 ++++++++++++++++++--------------
 5 files changed, 278 insertions(+), 227 deletions(-)

diff --git a/examples/components/src/main.rs b/examples/components/src/main.rs
index b635807b0..c3a1e8e6f 100644
--- a/examples/components/src/main.rs
+++ b/examples/components/src/main.rs
@@ -8,7 +8,7 @@ pub fn MyComponent(num: StateHandle<i32>) -> TemplateResult {
             # "My component"
             p {
                 # "Value: "
-                # num()
+                # num.get()
             }
         }
     }
@@ -18,13 +18,12 @@ fn main() {
     console_error_panic_hook::set_once();
     console_log::init_with_level(log::Level::Debug).unwrap();
 
-    let (state, set_state) = create_signal(1);
+    let state = Signal::new(1);
 
     let increment = {
         let state = state.clone();
-        let set_state = set_state.clone();
         move |_| {
-            set_state(*state() + 1);
+            state.set(*state.get() + 1);
         }
     };
 
@@ -34,8 +33,8 @@ fn main() {
                 # "Component demo"
             }
 
-            MyComponent(state.clone())
-            MyComponent(state.clone())
+            MyComponent(state.handle())
+            MyComponent(state.handle())
 
             button(on:click=increment) {
                 # "Increment"
diff --git a/examples/counter/src/main.rs b/examples/counter/src/main.rs
index 985588dde..fe069ef4d 100644
--- a/examples/counter/src/main.rs
+++ b/examples/counter/src/main.rs
@@ -4,26 +4,23 @@ fn main() {
     console_error_panic_hook::set_once();
     console_log::init_with_level(log::Level::Debug).unwrap();
 
-    let (counter, set_counter) = create_signal(0);
+    let counter = Signal::new(0);
 
     create_effect({
         let counter = counter.clone();
         move || {
-            log::info!("Counter value: {}", *counter());
+            log::info!("Counter value: {}", *counter.get());
         }
     });
 
     let increment = {
         let counter = counter.clone();
-        let set_counter = set_counter.clone();
-
-        move |_| set_counter(*counter() + 1)
+        move |_| counter.set(*counter.get() + 1)
     };
 
     let reset = {
-        let set_counter = set_counter.clone();
-
-        move |_| set_counter(0)
+        let counter = counter.clone();
+        move |_| counter.set(0)
     };
 
     let root = template! {
@@ -31,7 +28,7 @@ fn main() {
             # "Counter demo"
             p(class="value") {
                 # "Value: "
-                # counter()
+                # counter.get()
             }
             button(class="increment", on:click=increment) {
                 # "Increment"
diff --git a/examples/hello/src/main.rs b/examples/hello/src/main.rs
index e72da0743..a277e8ef4 100644
--- a/examples/hello/src/main.rs
+++ b/examples/hello/src/main.rs
@@ -8,18 +8,21 @@ fn main() {
     console_error_panic_hook::set_once();
     console_log::init_with_level(log::Level::Debug).unwrap();
 
-    let (name, set_name) = create_signal(String::new());
-
-    let displayed_name = create_memo(move || {
-        if *name() == "" {
-            "World".to_string()
-        } else {
-            name().as_ref().clone()
+    let name = Signal::new(String::new());
+
+    let displayed_name = create_memo({
+        let name = name.clone();
+        move || {
+            if name.get().is_empty() {
+                "World".to_string()
+            } else {
+                name.get().as_ref().clone()
+            }
         }
     });
 
     let handle_change = move |event: Event| {
-        set_name(
+        name.set(
             event
                 .target()
                 .unwrap()
@@ -33,7 +36,7 @@ fn main() {
         div {
             h1 {
                 # "Hello "
-                # displayed_name()
+                # displayed_name.get()
                 # "!"
             }
 
diff --git a/maple-core/src/lib.rs b/maple-core/src/lib.rs
index bd94aa38c..25394c3c7 100644
--- a/maple-core/src/lib.rs
+++ b/maple-core/src/lib.rs
@@ -44,10 +44,7 @@ impl TemplateResult {
 
 /// The maple prelude.
 pub mod prelude {
-    pub use crate::reactive::{
-        create_effect, create_memo, create_selector, create_signal, untracked, SetStateHandle,
-        StateHandle,
-    };
+    pub use crate::reactive::{create_effect, create_memo, create_selector, Signal, StateHandle};
     pub use crate::{render, TemplateResult};
 
     pub use maple_core_macro::template;
diff --git a/maple-core/src/reactive.rs b/maple-core/src/reactive.rs
index 411f3aef2..c6191ea4c 100644
--- a/maple-core/src/reactive.rs
+++ b/maple-core/src/reactive.rs
@@ -1,20 +1,131 @@
 //! Reactive primitives.
 
 use std::cell::RefCell;
+use std::ops::Deref;
 use std::rc::Rc;
 
 /// Returned by functions that provide a handle to access state.
-pub type StateHandle<T> = Rc<dyn Fn() -> Rc<T>>;
+pub struct StateHandle<T: 'static>(Rc<RefCell<SignalInner<T>>>);
+
+impl<T: 'static> StateHandle<T> {
+    /// Get the current value of the state.
+    pub fn get(&self) -> Rc<T> {
+        // if inside an effect, add this signal to dependency list
+        DEPENDENCIES.with(|dependencies| {
+            if dependencies.borrow().is_some() {
+                let signal = self.0.clone();
+
+                dependencies
+                    .borrow_mut()
+                    .as_mut()
+                    .unwrap()
+                    .push(Box::new(move |handler| {
+                        signal.borrow_mut().observe(handler.clone())
+                    }));
+            }
+        });
+
+        self.get_untracked()
+    }
+
+    /// Get the current value of the state, without tracking this as a dependency if inside a
+    /// reactive context.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use maple_core::prelude::*;
+    ///
+    /// let state = Signal::new(1);
+    ///
+    /// let double = create_memo({
+    ///     let state = state.clone();
+    ///     move || *state.get_untracked() * 2
+    /// });
+    ///
+    /// assert_eq!(*double.get(), 2);
+    ///
+    /// state.set(2);
+    /// // double value should still be old value because state was untracked
+    /// assert_eq!(*double.get(), 2);
+    /// ```
+    pub fn get_untracked(&self) -> Rc<T> {
+        self.0.borrow().inner.clone()
+    }
+}
+
+impl<T: 'static> Clone for StateHandle<T> {
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+/// State that can be set.
+pub struct Signal<T: 'static>(StateHandle<T>);
+
+impl<T: 'static> Signal<T> {
+    /// Creates a new signal.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use maple_core::prelude::*;
+    ///
+    /// let state = Signal::new(0);
+    /// assert_eq!(*state.get(), 0);
+    ///
+    /// state.set(1);
+    /// assert_eq!(*state.get(), 1);
+    /// ```
+    pub fn new(value: T) -> Self {
+        Self(StateHandle(Rc::new(RefCell::new(SignalInner::new(value)))))
+    }
+
+    /// Set the current value of the state.
+    ///
+    /// This will notify and update any effects and memos that depend on this value.
+    pub fn set(&self, new_value: T) {
+        match self.0 .0.try_borrow_mut() {
+            Ok(mut signal) => signal.update(new_value),
+            // If the signal is already borrowed, that means it is borrowed in the getter, thus creating a cyclic dependency.
+            Err(_err) => panic!("cannot create cyclic dependency"),
+        }
+        self.0 .0.borrow().trigger_observers();
+    }
 
-/// Returned by functions that provide a closure to modify state.
-pub type SetStateHandle<T> = Rc<dyn Fn(T)>;
+    /// Get the [`StateHandle`] associated with this signal.
+    ///
+    /// This is a shortcut for `(*signal).clone()`.
+    pub fn handle(&self) -> StateHandle<T> {
+        self.0.clone()
+    }
 
-struct Signal<T> {
+    /// Convert this signal into its underlying handle.
+    pub fn into_handle(self) -> StateHandle<T> {
+        self.0
+    }
+}
+
+impl<T: 'static> Deref for Signal<T> {
+    type Target = StateHandle<T>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl<T: 'static> Clone for Signal<T> {
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+struct SignalInner<T> {
     inner: Rc<T>,
     observers: Vec<Rc<Computation>>,
 }
 
-impl<T> Signal<T> {
+impl<T> SignalInner<T> {
     fn new(value: T) -> Self {
         Self {
             inner: Rc::new(value),
@@ -51,168 +162,102 @@ impl<T> Signal<T> {
 /// A derived computation from a signal.
 struct Computation(Box<dyn Fn()>);
 
-thread_local! {
-    static HANDLER: RefCell<Option<Rc<Computation>>> = RefCell::new(None);
+type Dependency = Box<dyn Fn(&Rc<Computation>)>;
 
+thread_local! {
     /// To add the dependencies, iterate through functions and execute them.
-    static DEPENDENCIES: RefCell<Option<Vec<Box<dyn Fn()>>>> = RefCell::new(None);
-}
-
-/// Creates a new signal.
-/// The function will return a pair of getter/setters to modify the signal and update corresponding dependencies.
-///
-/// # Example
-/// ```rust
-/// use maple_core::prelude::*;
-///
-/// let (state, set_state) = create_signal(0);
-/// assert_eq!(*state(), 0);
-///
-/// set_state(1);
-/// assert_eq!(*state(), 1);
-/// ```
-pub fn create_signal<T: 'static>(value: T) -> (StateHandle<T>, SetStateHandle<T>) {
-    let signal = Rc::new(RefCell::new(Signal::new(value)));
-
-    let getter = {
-        let signal = signal.clone();
-        move || {
-            // if inside an effect, add this signal to dependency list
-            DEPENDENCIES.with(|dependencies| {
-                if dependencies.borrow().is_some() {
-                    let signal = signal.clone();
-                    let handler =
-                        HANDLER.with(|handler| handler.borrow().as_ref().unwrap().clone());
-
-                    dependencies
-                        .borrow_mut()
-                        .as_mut()
-                        .unwrap()
-                        .push(Box::new(move || {
-                            signal.borrow_mut().observe(handler.clone())
-                        }));
-                }
-            });
-
-            signal.borrow().inner.clone()
-        }
-    };
-
-    let setter = {
-        let signal = signal.clone();
-        move |new_value| {
-            match signal.try_borrow_mut() {
-                Ok(mut signal) => signal.update(new_value),
-                // If the signal is already borrowed, that means it is borrowed in the getter, thus creating a cyclic dependency.
-                Err(_err) => panic!("cannot create cyclic dependency"),
-            };
-            signal.borrow().trigger_observers();
-        }
-    };
-
-    (Rc::new(getter), Rc::new(setter))
+    static DEPENDENCIES: RefCell<Option<Vec<Dependency>>> = RefCell::new(None);
 }
 
 /// Creates an effect on signals used inside the effect closure.
-pub fn create_effect<F>(effect: F)
-where
-    F: Fn() + 'static,
-{
+///
+/// Unlike [`create_effect`], this will allow the closure to run different code upon first
+/// execution, so it can return a value.
+fn create_effect_initial<R>(initial: impl FnOnce() -> (Rc<Computation>, R)) -> R {
     DEPENDENCIES.with(|dependencies| {
         if dependencies.borrow().is_some() {
             unimplemented!("nested dependencies are not supported")
         }
 
-        let effect = Rc::new(Computation(Box::new(effect)));
-
         *dependencies.borrow_mut() = Some(Vec::new());
-        HANDLER.with(|handler| *handler.borrow_mut() = Some(effect.clone()));
 
         // run effect for the first time to attach all the dependencies
-        effect.0();
+        let (effect, ret) = initial();
 
         // attach dependencies
         for dependency in dependencies.borrow().as_ref().unwrap() {
-            dependency();
+            dependency(&effect);
         }
 
         // Reset dependencies for next effect hook
         *dependencies.borrow_mut() = None;
+
+        ret
     })
 }
 
-/// Prevents tracking dependencies inside the closure. If called outside a reactive context, does nothing.
-///
-/// # Example
-/// ```rust
-/// use maple_core::prelude::*;
-///
-/// let (state, set_state) = create_signal(1);
-///
-/// let double = create_memo(move || untracked(|| *state()) * 2);
-///
-/// assert_eq!(*double(), 2);
-///
-/// set_state(2);
-/// assert_eq!(*double(), 2); // double value should still be old value because state() was inside untracked
-/// ```
-pub fn untracked<F, Out>(f: F) -> Out
+/// Creates an effect on signals used inside the effect closure.
+pub fn create_effect<F>(effect: F)
 where
-    F: Fn() -> Out,
+    F: Fn() + 'static,
 {
-    let tmp = DEPENDENCIES.with(|dependencies| dependencies.take());
-    let out = f();
-    DEPENDENCIES.with(|dependencies| *dependencies.borrow_mut() = tmp);
-
-    out
+    create_effect_initial(move || {
+        effect();
+        (Rc::new(Computation(Box::new(effect))), ())
+    })
 }
 
 /// Creates a memoized value from some signals. Also know as "derived stores".
 pub fn create_memo<F, Out>(derived: F) -> StateHandle<Out>
 where
     F: Fn() -> Out + 'static,
-    Out: Clone + 'static,
+    Out: 'static,
 {
-    let derived = Rc::new(derived);
-    let (memo, set_memo) = create_signal(None);
-
-    create_effect({
-        let derived = derived.clone();
-        move || {
-            set_memo(Some(derived()));
-        }
-    });
-
-    // return memoized result
-    let memo_result = move || Rc::new(Option::as_ref(&memo()).unwrap().clone());
-    Rc::new(memo_result)
+    create_selector_with(derived, |_, _| false)
 }
 
 /// Creates a memoized value from some signals. Also know as "derived stores".
 /// Unlike [`create_memo`], this function will not notify dependents of a change if the output is the same.
-/// That is why the output of the function must implement `PartialEq`.
+/// That is why the output of the function must implement [`PartialEq`].
+///
+/// To specify a custom comparison function, use [`create_selector_with`].
 pub fn create_selector<F, Out>(derived: F) -> StateHandle<Out>
 where
     F: Fn() -> Out + 'static,
-    Out: Clone + PartialEq + std::fmt::Debug + 'static,
+    Out: PartialEq + 'static,
 {
-    let derived = Rc::new(derived);
-    let (memo, set_memo) = create_signal(None);
-
-    create_effect({
-        let derived = derived.clone();
-        let memo = memo.clone();
-        move || {
-            let new_value = Some(derived());
-            if *untracked(|| memo()) != new_value {
-                set_memo(new_value);
+    create_selector_with(derived, PartialEq::eq)
+}
+
+/// Creates a memoized value from some signals. Also know as "derived stores".
+/// Unlike [`create_memo`], this function will not notify dependents of a change if the output is the same.
+///
+/// It takes a comparison function to compare the old and new value, which returns `true` if they
+/// are the same and `false` otherwise.
+///
+/// To use the type's [`PartialEq`] implementation instead of a custom function, use
+/// [`create_selector`].
+pub fn create_selector_with<F, Out, C>(derived: F, comparator: C) -> StateHandle<Out>
+where
+    F: Fn() -> Out + 'static,
+    Out: 'static,
+    C: Fn(&Out, &Out) -> bool + 'static,
+{
+    create_effect_initial(|| {
+        let memo = Signal::new(derived());
+
+        let effect = Rc::new(Computation(Box::new({
+            let memo = memo.clone();
+            move || {
+                let new_value = derived();
+                if !comparator(&memo.get_untracked(), &new_value) {
+                    memo.set(new_value);
+                }
             }
-        }
-    });
+        })));
 
-    // return memoized result
-    let memo_result = move || Rc::new(Option::as_ref(&memo()).unwrap().clone());
-    Rc::new(memo_result)
+        (effect, memo.into_handle())
+    })
 }
 
 #[cfg(test)]
@@ -221,190 +266,200 @@ mod tests {
 
     #[test]
     fn signals() {
-        let (state, set_state) = create_signal(0);
-        assert_eq!(*state(), 0);
+        let state = Signal::new(0);
+        assert_eq!(*state.get(), 0);
 
-        set_state(1);
-        assert_eq!(*state(), 1);
+        state.set(1);
+        assert_eq!(*state.get(), 1);
     }
 
     #[test]
     fn signal_composition() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let double = || *state() * 2;
+        let double = || *state.get() * 2;
 
         assert_eq!(double(), 0);
 
-        set_state(1);
+        state.set(1);
         assert_eq!(double(), 2);
     }
 
     #[test]
     fn effects() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let (double, set_double) = create_signal(-1);
+        let double = Signal::new(-1);
 
         create_effect({
-            let set_double = set_double.clone();
+            let state = state.clone();
+            let double = double.clone();
             move || {
-                set_double(*state() * 2);
+                double.set(*state.get() * 2);
             }
         });
-        assert_eq!(*double(), 0); // calling create_effect should call the effect at least once
+        assert_eq!(*double.get(), 0); // calling create_effect should call the effect at least once
 
-        set_state(1);
-        assert_eq!(*double(), 2);
-        set_state(2);
-        assert_eq!(*double(), 4);
+        state.set(1);
+        assert_eq!(*double.get(), 2);
+        state.set(2);
+        assert_eq!(*double.get(), 4);
     }
 
     #[test]
     #[should_panic(expected = "cannot create cyclic dependency")]
     fn cyclic_effects_fail() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
         create_effect({
             let state = state.clone();
-            let set_state = set_state.clone();
             move || {
-                set_state(*state() + 1);
+                state.set(*state.get() + 1);
             }
         });
 
-        set_state(1);
+        state.set(1);
     }
 
     #[test]
     #[should_panic(expected = "cannot create cyclic dependency")]
     fn cyclic_effects_fail_2() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
         create_effect({
             let state = state.clone();
-            let set_state = set_state.clone();
             move || {
-                let value = *state();
-                set_state(value + 1);
+                let value = *state.get();
+                state.set(value + 1);
             }
         });
 
-        set_state(1);
+        state.set(1);
     }
 
     #[test]
     fn effect_should_subscribe_once() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let (counter, set_counter) = create_signal(0);
+        let counter = Signal::new(0);
         create_effect({
+            let state = state.clone();
             let counter = counter.clone();
             move || {
-                set_counter(untracked(|| *counter()) + 1);
+                counter.set(*counter.get_untracked() + 1);
 
-                // call state() twice but should subscribe once
-                state();
-                state();
+                // call state.get() twice but should subscribe once
+                state.get();
+                state.get();
             }
         });
 
-        assert_eq!(*counter(), 1);
+        assert_eq!(*counter.get(), 1);
 
-        set_state(1);
-        assert_eq!(*counter(), 2);
+        state.set(1);
+        assert_eq!(*counter.get(), 2);
     }
 
     #[test]
     fn memo() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let double = create_memo(move || *state() * 2);
-        assert_eq!(*double(), 0);
+        let double = create_memo({
+            let state = state.clone();
+            move || *state.get() * 2
+        });
+        assert_eq!(*double.get(), 0);
 
-        set_state(1);
-        assert_eq!(*double(), 2);
+        state.set(1);
+        assert_eq!(*double.get(), 2);
 
-        set_state(2);
-        assert_eq!(*double(), 4);
+        state.set(2);
+        assert_eq!(*double.get(), 4);
     }
 
     #[test]
     /// Make sure value is memoized rather than executed on demand.
     fn memo_only_run_once() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let (counter, set_counter) = create_signal(0);
+        let counter = Signal::new(0);
         let double = create_memo({
+            let state = state.clone();
             let counter = counter.clone();
             move || {
-                set_counter(untracked(|| *counter()) + 1);
+                counter.set(*counter.get_untracked() + 1);
 
-                *state() * 2
+                *state.get() * 2
             }
         });
-        assert_eq!(*counter(), 1); // once for calculating initial derived state
+        assert_eq!(*counter.get(), 1); // once for calculating initial derived state
 
-        set_state(2);
-        assert_eq!(*counter(), 2);
-        assert_eq!(*double(), 4);
-        assert_eq!(*counter(), 2); // should still be 2 after access
+        state.set(2);
+        assert_eq!(*counter.get(), 2);
+        assert_eq!(*double.get(), 4);
+        assert_eq!(*counter.get(), 2); // should still be 2 after access
     }
 
     #[test]
     fn dependency_on_memo() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
-        let double = create_memo(move || *state() * 2);
+        let double = create_memo({
+            let state = state.clone();
+            move || *state.get() * 2
+        });
 
-        let quadruple = create_memo(move || *double() * 2);
+        let quadruple = create_memo(move || *double.get() * 2);
 
-        assert_eq!(*quadruple(), 0);
+        assert_eq!(*quadruple.get(), 0);
 
-        set_state(1);
-        assert_eq!(*quadruple(), 4);
+        state.set(1);
+        assert_eq!(*quadruple.get(), 4);
     }
 
     #[test]
     fn untracked_memo() {
-        let (state, set_state) = create_signal(1);
+        let state = Signal::new(1);
 
-        let double = create_memo(move || untracked(|| *state()) * 2);
+        let double = create_memo({
+            let state = state.clone();
+            move || *state.get_untracked() * 2
+        });
 
-        assert_eq!(*double(), 2);
+        assert_eq!(*double.get(), 2);
 
-        set_state(2);
-        assert_eq!(*double(), 2); // double value should still be true because state() was inside untracked
+        state.set(2);
+        assert_eq!(*double.get(), 2); // double value should still be true because state.get() was inside untracked
     }
 
     #[test]
     fn selector() {
-        let (state, set_state) = create_signal(0);
+        let state = Signal::new(0);
 
         let double = create_selector({
             let state = state.clone();
-            move || *state() * 2
+            move || *state.get() * 2
         });
 
-        let (counter, set_counter) = create_signal(0);
+        let counter = Signal::new(0);
         create_effect({
             let counter = counter.clone();
             let double = double.clone();
             move || {
-                set_counter(untracked(|| *counter()) + 1);
+                counter.set(*counter.get_untracked() + 1);
 
-                double();
+                double.get();
             }
         });
-        assert_eq!(*double(), 0);
-        assert_eq!(*counter(), 1);
+        assert_eq!(*double.get(), 0);
+        assert_eq!(*counter.get(), 1);
 
-        set_state(0);
-        assert_eq!(*double(), 0);
-        assert_eq!(*counter(), 1); // calling set_state should not trigger the effect
+        state.set(0);
+        assert_eq!(*double.get(), 0);
+        assert_eq!(*counter.get(), 1); // calling set_state should not trigger the effect
 
-        set_state(2);
-        assert_eq!(*double(), 4);
-        assert_eq!(*counter(), 2);
+        state.set(2);
+        assert_eq!(*double.get(), 4);
+        assert_eq!(*counter.get(), 2);
     }
 }