From 51d9564eb520982a2d160b1ab9f153a032007ad8 Mon Sep 17 00:00:00 2001 From: 191220029 <522023330025@smail.nju.edu.cn> Date: Sat, 21 Dec 2024 17:10:00 +0800 Subject: [PATCH] examples & fix of async graph run Signed-off-by: A-Mavericks <363136637@qq.com> Fix auto_node example hello_dagrs & compute_dag example hello_dagrs & compute_dag --- Cargo.toml | 1 + derive/src/auto_node.rs | 11 +- derive/src/relay.rs | 10 +- examples/auto_node.rs | 9 +- examples/auto_relay.rs | 60 +++++------ examples/compute_dag.rs | 106 +++++++++++++++++++ examples/custom_node.rs | 75 ++++++++++++++ examples/hello_dagrs.rs | 46 ++++++++ src/connection/in_channel.rs | 35 ++++++- src/connection/out_channel.rs | 21 +++- src/graph/graph.rs | 190 +++++++++++++++------------------- src/lib.rs | 1 + src/node/default_node.rs | 32 +++--- src/node/node.rs | 7 +- src/utils/execstate.rs | 6 -- src/utils/output.rs | 4 +- 16 files changed, 438 insertions(+), 176 deletions(-) create mode 100644 examples/compute_dag.rs create mode 100644 examples/custom_node.rs create mode 100644 examples/hello_dagrs.rs diff --git a/Cargo.toml b/Cargo.toml index e66d794..720bb01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ env_logger = "0.10.1" async-trait = "0.1.83" derive = { path = "derive", optional = true } proc-macro2 = "1.0" +futures = "0.3.31" [dev-dependencies] simplelog = "0.12" diff --git a/derive/src/auto_node.rs b/derive/src/auto_node.rs index 91e399e..fd612c3 100644 --- a/derive/src/auto_node.rs +++ b/derive/src/auto_node.rs @@ -116,6 +116,7 @@ fn auto_impl_node( ]); quote::quote!( + #[async_trait::async_trait] impl #generics dagrs::Node for #struct_ident #generics { #impl_tokens } @@ -169,12 +170,10 @@ fn impl_run( let in_channels_ident = &field_in_channels.ident; let out_channels_ident = &field_out_channels.ident; quote::quote!( - fn run(&mut self, env: std::sync::Arc) -> dagrs::Output { - tokio::runtime::Runtime::new().unwrap().block_on(async { - self.#ident - .run(&mut self.#in_channels_ident, &self.#out_channels_ident, env) - .await - }) + async fn run(&mut self, env: std::sync::Arc) -> dagrs::Output { + self.#ident + .run(&mut self.#in_channels_ident, &self.#out_channels_ident, env) + .await } ) } diff --git a/derive/src/relay.rs b/derive/src/relay.rs index f586f30..d034888 100644 --- a/derive/src/relay.rs +++ b/derive/src/relay.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use proc_macro2::Ident; use syn::{parse::Parse, Token}; @@ -77,16 +77,16 @@ pub(crate) fn add_relay(relaies: Relaies) -> proc_macro2::TokenStream { } for relay in relaies.0.iter() { let task = relay.task.clone(); - if (!cache.contains(&task)) { + if !cache.contains(&task) { token.extend(quote::quote!( - graph.add_node(Box::new(#task)); + graph.add_node(#task); )); cache.insert(task); } for successor in relay.successors.iter() { - if (!cache.contains(successor)) { + if !cache.contains(successor) { token.extend(quote::quote!( - graph.add_node(Box::new(#successor)); + graph.add_node(#successor); )); cache.insert(successor.clone()); } diff --git a/examples/auto_node.rs b/examples/auto_node.rs index a6653f4..0585739 100644 --- a/examples/auto_node.rs +++ b/examples/auto_node.rs @@ -1,3 +1,7 @@ +//! # Example: auto_node +//! The procedural macro `auto_node` simplifies the implementation of `Node` trait for custom types. +//! It works on structs except [tuple structs](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#using-tuple-structs-without-named-fields-to-create-different-types). + use std::sync::Arc; use dagrs::{auto_node, EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels}; @@ -7,6 +11,7 @@ struct MyNode {/*Put customized fields here.*/} #[auto_node] struct _MyNodeGeneric { + /*Put customized fields here.*/ my_field: Vec, my_name: &'a str, } @@ -30,7 +35,9 @@ fn main() { assert_eq!(&s.id(), node_table.get(&node_name).unwrap()); assert_eq!(&s.name(), &node_name); - let output = s.run(Arc::new(EnvVar::new(NodeTable::default()))); + let output = tokio::runtime::Runtime::new() + .unwrap() + .block_on(async { s.run(Arc::new(EnvVar::new(NodeTable::default()))).await }); match output { dagrs::Output::Out(content) => assert!(content.is_none()), _ => panic!(), diff --git a/examples/auto_relay.rs b/examples/auto_relay.rs index c14d1ca..42baaa0 100644 --- a/examples/auto_relay.rs +++ b/examples/auto_relay.rs @@ -1,45 +1,37 @@ -use std::sync::Arc; +//! # Example: auto_relay +//! The macro `dependencies!` simplifies the construction of a `Graph`, +//! including the addition of nodes and edges. -use dagrs::{ - auto_node, dependencies, - graph::{self, graph::Graph}, - EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels, -}; +use dagrs::{auto_node, dependencies, EmptyAction, InChannels, Node, NodeTable, OutChannels}; #[auto_node] struct MyNode {/*Put customized fields here.*/} +impl MyNode { + fn new(name: &str, node_table: &mut NodeTable) -> Self { + Self { + id: node_table.alloc_id_for(name), + name: name.to_string(), + input_channels: InChannels::default(), + output_channels: OutChannels::default(), + action: Box::new(EmptyAction), + } + } +} + fn main() { let mut node_table = NodeTable::default(); - let node_name = "auto_node".to_string(); - - let s = MyNode { - id: node_table.alloc_id_for(&node_name), - name: node_name.clone(), - input_channels: InChannels::default(), - output_channels: OutChannels::default(), - action: Box::new(EmptyAction), - }; - - let a = MyNode { - id: node_table.alloc_id_for(&node_name), - name: node_name.clone(), - input_channels: InChannels::default(), - output_channels: OutChannels::default(), - action: Box::new(EmptyAction), - }; - - let b = MyNode { - id: node_table.alloc_id_for(&node_name), - name: node_name.clone(), - input_channels: InChannels::default(), - output_channels: OutChannels::default(), - action: Box::new(EmptyAction), - }; - let mut g = dependencies!(s -> a b, - b -> a + let node_name = "auto_node"; + + let s = MyNode::new(node_name, &mut node_table); + let a = MyNode::new(node_name, &mut node_table); + let b = MyNode::new(node_name, &mut node_table); + + let mut g = dependencies!( + s -> a b, + b -> a ); - g.run(); + g.start(); } diff --git a/examples/compute_dag.rs b/examples/compute_dag.rs new file mode 100644 index 0000000..011294c --- /dev/null +++ b/examples/compute_dag.rs @@ -0,0 +1,106 @@ +//! Only use Dag, execute a job. The graph is as follows: +//! +//! ↱----------↴ +//! B -→ E --→ G +//! ↗ ↗ ↗ +//! A --→ C / +//! ↘ ↘ / +//! D -→ F +//! +//! The final execution result is 272. + +use std::sync::Arc; + +use async_trait::async_trait; +use dagrs::{ + Action, Content, DefaultNode, EnvVar, Graph, InChannels, Node, NodeTable, OutChannels, Output, +}; + +const BASE: &str = "base"; + +struct Compute(usize); + +#[async_trait] +impl Action for Compute { + async fn run( + &self, + in_channels: &mut InChannels, + out_channels: &OutChannels, + env: Arc, + ) -> Output { + let base = env.get::(BASE).unwrap(); + let mut sum = self.0; + + in_channels + .map(|content| content.unwrap().into_inner::().unwrap()) + .await + .into_iter() + .for_each(|x| sum += *x * base); + + out_channels.broadcast(Content::new(sum)).await; + + Output::Out(Some(Content::new(sum))) + } +} + +fn main() { + // Initialization log. + env_logger::init(); + + // Create a new `NodeTable`. + let mut node_table = NodeTable::default(); + + // Generate some tasks. + let a = DefaultNode::with_action("Compute A".to_string(), Compute(1), &mut node_table); + let a_id = a.id(); + + let b = DefaultNode::with_action("Compute B".to_string(), Compute(2), &mut node_table); + let b_id = b.id(); + + let mut c = DefaultNode::new("Compute C".to_string(), &mut node_table); + c.set_action(Compute(4)); + let c_id = c.id(); + + let mut d = DefaultNode::new("Compute D".to_string(), &mut node_table); + d.set_action(Compute(8)); + let d_id = d.id(); + + let e = DefaultNode::with_action("Compute E".to_string(), Compute(16), &mut node_table); + let e_id = e.id(); + let f = DefaultNode::with_action("Compute F".to_string(), Compute(32), &mut node_table); + let f_id = f.id(); + + let g = DefaultNode::with_action("Compute G".to_string(), Compute(64), &mut node_table); + let g_id = g.id(); + + // Create a graph. + let mut graph = Graph::new(); + vec![a, b, c, d, e, f, g] + .into_iter() + .for_each(|node| graph.add_node(node)); + + // Set up task dependencies. + graph.add_edge(a_id, vec![b_id, c_id, d_id]); + graph.add_edge(b_id, vec![e_id, g_id]); + graph.add_edge(c_id, vec![e_id, f_id]); + graph.add_edge(d_id, vec![f_id]); + graph.add_edge(e_id, vec![g_id]); + graph.add_edge(f_id, vec![g_id]); + + // Set a global environment variable for this dag. + let mut env = EnvVar::new(node_table); + env.set("base", 2usize); + graph.set_env(env); + + // Start executing this dag. + graph.start(); + + // Verify execution result. + let res = graph + .get_results::() + .get(&g_id) + .unwrap() + .clone() + .unwrap(); + assert_eq!(*res, 272) +} diff --git a/examples/custom_node.rs b/examples/custom_node.rs new file mode 100644 index 0000000..585986d --- /dev/null +++ b/examples/custom_node.rs @@ -0,0 +1,75 @@ +//! # Example: custom_node +//! Creates a custom implementation of [`Node`] that returns a [`String`], +//! then create a new [`Graph`] with this node and run. + +use std::sync::Arc; + +use async_trait::async_trait; +use dagrs::{ + Content, EnvVar, Graph, InChannels, Node, NodeId, NodeName, NodeTable, OutChannels, Output, +}; + +struct MessageNode { + id: NodeId, + name: NodeName, + in_channels: InChannels, + out_channels: OutChannels, + /*Put your custom fields here.*/ + message: String, +} + +#[async_trait] +impl Node for MessageNode { + fn id(&self) -> NodeId { + self.id + } + + fn name(&self) -> NodeName { + self.name.clone() + } + + fn input_channels(&mut self) -> &mut InChannels { + &mut self.in_channels + } + + fn output_channels(&mut self) -> &mut OutChannels { + &mut self.out_channels + } + + async fn run(&mut self, _: Arc) -> Output { + Output::Out(Some(Content::new(self.message.clone()))) + } +} + +impl MessageNode { + fn new(name: String, node_table: &mut NodeTable) -> Self { + Self { + id: node_table.alloc_id_for(&name), + name, + in_channels: InChannels::default(), + out_channels: OutChannels::default(), + message: "hello dagrs".to_string(), + } + } +} + +fn main() { + // create an empty `NodeTable` + let mut node_table = NodeTable::new(); + // create a `MessageNode` + let node = MessageNode::new("message node".to_string(), &mut node_table); + let id: &dagrs::NodeId = &node.id(); + + // create a graph with this node and run + let mut graph = Graph::new(); + graph.add_node(node); + graph.start(); + + // verify the output of this node + let outputs = graph.get_outputs(); + assert_eq!(outputs.len(), 1); + + let content = outputs.get(id).unwrap().get_out().unwrap(); + let node_output = content.get::().unwrap(); + assert_eq!(node_output, "hello dagrs") +} diff --git a/examples/hello_dagrs.rs b/examples/hello_dagrs.rs new file mode 100644 index 0000000..38e3bed --- /dev/null +++ b/examples/hello_dagrs.rs @@ -0,0 +1,46 @@ +//! # Example: hello_dagrs +//! Creates a `DefaultNode` that returns with "Hello Dagrs", +//! then create a new `Graph` with this node and run. + +use std::sync::Arc; + +use async_trait::async_trait; +use dagrs::{ + Action, Content, DefaultNode, EnvVar, Graph, InChannels, Node, NodeTable, OutChannels, Output, +}; + +/// An implementation of [`Action`] that returns [`Output::Out`] containing a String "Hello world". +#[derive(Default)] +pub struct HelloAction; +#[async_trait] +impl Action for HelloAction { + async fn run(&self, _: &mut InChannels, _: &OutChannels, _: Arc) -> Output { + Output::Out(Some(Content::new("Hello Dagrs".to_string()))) + } +} + +fn main() { + // create an empty `NodeTable` + let mut node_table = NodeTable::new(); + // create a `DefaultNode` with action `HelloAction` + let hello_node = DefaultNode::with_action( + "Hello Dagrs".to_string(), + HelloAction::default(), + &mut node_table, + ); + let id: &dagrs::NodeId = &hello_node.id(); + + // create a graph with this node and run + let mut graph = Graph::new(); + graph.add_node(hello_node); + + graph.start(); + + // verify the output of this node + let outputs = graph.get_outputs(); + assert_eq!(outputs.len(), 1); + + let content = outputs.get(id).unwrap().get_out().unwrap(); + let node_output = content.get::().unwrap(); + assert_eq!(node_output, "Hello Dagrs") +} diff --git a/src/connection/in_channel.rs b/src/connection/in_channel.rs index e79ac19..0acfaa6 100644 --- a/src/connection/in_channel.rs +++ b/src/connection/in_channel.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use futures::future::join_all; use tokio::sync::{broadcast, mpsc, Mutex}; use crate::node::node::NodeId; @@ -28,6 +29,31 @@ impl InChannels { } } + /// Calls `blocking_recv` for all the [`InChannel`]s, and applies transformation `f` to + /// the return values of the call. + pub fn blocking_map(&mut self, mut f: F) -> Vec + where + F: FnMut(Result) -> T, + { + self.keys() + .into_iter() + .map(|id| f(self.blocking_recv_from(&id))) + .collect() + } + + /// Calls `recv` for all the [`InChannel`]s, and applies transformation `f` to + /// the return values of the call asynchronously. + pub async fn map(&mut self, mut f: F) -> Vec + where + F: FnMut(Result) -> T, + { + let futures = self + .0 + .iter_mut() + .map(|(_, c)| async { c.lock().await.recv().await }); + join_all(futures).await.into_iter().map(|x| f(x)).collect() + } + /// Close the channel by the given `NodeId`, and remove the channel in this map. pub fn close(&mut self, id: &NodeId) { if let Some(c) = self.get(id) { @@ -36,14 +62,19 @@ impl InChannels { } } + pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc>) { + self.0.insert(node_id, channel); + } + fn get(&self, id: &NodeId) -> Option>> { match self.0.get(id) { Some(c) => Some(c.clone()), None => None, } } - pub fn insert(&mut self, node_id: NodeId, channel: Arc>) { - self.0.insert(node_id, channel); + + fn keys(&self) -> Vec { + self.0.keys().map(|x| *x).collect() } } diff --git a/src/connection/out_channel.rs b/src/connection/out_channel.rs index 6dc917d..4469344 100644 --- a/src/connection/out_channel.rs +++ b/src/connection/out_channel.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use futures::future::join_all; use tokio::sync::{broadcast, mpsc}; use crate::node::node::NodeId; @@ -29,6 +30,24 @@ impl OutChannels { } } + /// Broadcasts the `content` to all the [`OutChannel`]s asynchronously. + pub async fn broadcast(&self, content: Content) -> Vec> { + let futures = self + .0 + .iter() + .map(|(_, c)| async { c.send(content.clone()).await }); + + join_all(futures).await + } + + /// Blocking broadcasts the `content` to all the [`OutChannel`]s. + pub fn blocking_broadcast(&self, content: Content) -> Vec> { + self.0 + .iter() + .map(|(_, c)| c.blocking_send(content.clone())) + .collect() + } + /// Close the channel by the given `NodeId`, and remove the channel in this map. pub fn close(&mut self, id: &NodeId) { if let Some(_) = self.get(id) { @@ -43,7 +62,7 @@ impl OutChannels { } } - pub fn insert(&mut self, node_id: NodeId, channel: Arc) { + pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc) { self.0.insert(node_id, channel); } } diff --git a/src/graph/graph.rs b/src/graph/graph.rs index 67d5fab..cf7a50f 100644 --- a/src/graph/graph.rs +++ b/src/graph/graph.rs @@ -1,5 +1,4 @@ use std::hash::Hash; -use std::sync::mpsc::channel; use std::{ collections::{HashMap, HashSet}, panic::{self, AssertUnwindSafe}, @@ -14,9 +13,9 @@ use crate::{ }; use log::{debug, error}; -use tokio::sync::broadcast; use tokio::sync::mpsc; use tokio::sync::Mutex; +use tokio::task; /// [`Graph`] is dagrs's main body. /// @@ -37,7 +36,7 @@ use tokio::sync::Mutex; pub struct Graph { /// Define the Net struct that holds all nodes - nodes: HashMap>, + nodes: HashMap>>, /// Store a task's running result.Execution results will be read /// and written asynchronously by several threads. execute_states: HashMap>, @@ -74,9 +73,10 @@ impl Graph { } /// Adds a new node to the `Graph` - pub fn add_node(&mut self, node: Box) { - self.node_count = self.node_count + 1; + pub fn add_node(&mut self, node: impl Node + 'static) { let id = node.id(); + let node = Arc::new(Mutex::new(node)); + self.node_count = self.node_count + 1; self.nodes.insert(id, node); self.in_degree.insert(id, 0); } @@ -85,84 +85,46 @@ impl Graph { /// An MPSC channel is used if the outgoing port of the sending node is empty and the number of receiving nodes is equal to 1 /// If the outgoing port of the sending node is not empty, adding any number of receiving nodes will change all relevant channels to broadcast pub fn add_edge(&mut self, from_id: NodeId, all_to_ids: Vec) { - let from_node = self.nodes.get_mut(&from_id).unwrap(); - let from_channel = from_node.output_channels(); let to_ids = Self::remove_duplicates(all_to_ids); - if from_channel.0.is_empty() { - if to_ids.len() > 1 { - let (bcst_sender, _) = broadcast::channel::(32); - { - for to_id in &to_ids { - from_channel - .insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone()))); - self.in_degree - .entry(*to_id) - .and_modify(|e| *e += 1) - .or_insert(0); - } - } - for to_id in &to_ids { - if let Some(to_node) = self.nodes.get_mut(to_id) { - let to_channel = to_node.input_channels(); - let receiver = bcst_sender.subscribe(); - to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver)))); - } - } - } else if let Some(to_id) = to_ids.get(0) { - let (tx, rx) = mpsc::channel::(32); - { + let mut rx_map: HashMap> = HashMap::new(); + { + let from_node_lock = self.nodes.get_mut(&from_id).unwrap(); + let mut from_node = from_node_lock.blocking_lock(); + let from_channel = from_node.output_channels(); + for to_id in &to_ids { + if !from_channel.0.contains_key(to_id) { + let (tx, rx) = mpsc::channel::(32); from_channel.insert(*to_id, Arc::new(OutChannel::Mpsc(tx.clone()))); + rx_map.insert(*to_id, rx); self.in_degree .entry(*to_id) .and_modify(|e| *e += 1) .or_insert(0); } - if let Some(to_node) = self.nodes.get_mut(to_id) { - let to_channel = to_node.input_channels(); - to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Mpsc(rx)))); - } } - } else { - if to_ids.len() > 1 - || (to_ids.len() == 1 && !from_channel.0.contains_key(to_ids.get(0).unwrap())) - { - let (bcst_sender, _) = broadcast::channel::(32); - { - for _channel in from_channel.0.values_mut() { - *_channel = Arc::new(OutChannel::Bcst(bcst_sender.clone())); - } - for to_id in &to_ids { - if !from_channel.0.contains_key(to_id) { - self.in_degree - .entry(*to_id) - .and_modify(|e| *e += 1) - .or_insert(0); - } - from_channel - .insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone()))); - } - } - for to_id in &to_ids { - if let Some(to_node) = self.nodes.get_mut(to_id) { - let to_channel = to_node.input_channels(); - let receiver = bcst_sender.subscribe(); - to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver)))); - } + } + for to_id in &to_ids { + if let Some(to_node_lock) = self.nodes.get_mut(to_id) { + let mut to_node = to_node_lock.blocking_lock(); + let to_channel = to_node.input_channels(); + if let Some(rx) = rx_map.remove(&to_id) { + to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Mpsc(rx)))); } } } } /// Initializes the network, setting up the nodes. - pub fn init(&mut self) { + pub(crate) fn init(&mut self) { self.execute_states.reserve(self.nodes.len()); - self.nodes.values().for_each(|node| { + self.nodes.keys().for_each(|node| { self.execute_states - .insert(node.id(), Arc::new(ExecState::new())); + .insert(*node, Arc::new(ExecState::new())); }); } - /// This function is used for the execution of a single net. - pub fn run(&mut self) { + + /// This function is used for the execution of a single dag. + pub fn start(&mut self) { self.init(); let is_loop = self.check_loop(); if is_loop { @@ -172,42 +134,59 @@ impl Graph { eprintln!("Graph is not active. Aborting execution."); return; } else { - for (node_id, node) in &mut self.nodes { - let execute_state = self.execute_states[&node_id].clone(); - panic::catch_unwind(AssertUnwindSafe(|| node.run(Arc::clone(&self.env)))) - .map_or_else( - |_| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(async { self.run().await }) + } + } + + async fn run(&mut self) { + let mut tasks = Vec::new(); + + for (node_id, node) in &self.nodes { + let execute_state = self.execute_states[&node_id].clone(); + let node_clone = Arc::clone(&self.env); + let node = Arc::clone(&node); + + let task = task::spawn(async move { + // Lock the node before running its method + let mut node = node.lock().await; + let node_name = node.name(); + let node_id = node.id().0; + let result = + panic::catch_unwind(AssertUnwindSafe( + || async move { node.run(node_clone).await }, + )); + + match result { + Ok(out) => { + let out = out.await; + if out.is_err() { + let error = out.get_err().unwrap_or("".to_string()); error!( - "Execution failed [name: {}, id: {}]", - node.name(), - node_id.0, + "Execution failed [name: {}, id: {}] - {}", + node_name, node_id, error ); - }, - |out| { - // Store execution results - if out.is_err() { - let error = out.get_err().unwrap_or("".to_string()); - error!( - "Execution failed [name: {}, id: {}] - {}", - node.name(), - node_id.0, - error - ); - execute_state.set_output(out); - execute_state.exe_fail(); - } else { - execute_state.set_output(out); - execute_state.exe_success(); - debug!( - "Execution succeed [name: {}, id: {}]", - node.name(), - node_id.0 - ); - } - }, - ) - } + execute_state.set_output(out); + execute_state.exe_fail(); + } else { + execute_state.set_output(out); + execute_state.exe_success(); + debug!("Execution succeed [name: {}, id: {}]", node_name, node_id,); + } + } + Err(_) => { + error!("Execution failed [name: {}, id: {}]", node_name, node_id,) + } + } + }); + + tasks.push(task); } + + // Await all tasks to complete + let _ = futures::future::join_all(tasks).await; + self.is_active .store(false, std::sync::atomic::Ordering::Relaxed); } @@ -225,9 +204,10 @@ impl Graph { while let Some(node_id) = queue.pop() { processed_count += 1; - let node = self.nodes.get_mut(&node_id).unwrap(); + let node_lock = self.nodes.get_mut(&node_id).unwrap(); + let mut node = node_lock.blocking_lock(); let out = node.output_channels(); - for (id, channel) in out.0.iter() { + for (id, _channel) in out.0.iter() { if let Some(degree) = in_degree.get_mut(id) { *degree -= 1; if *degree == 0 { @@ -298,8 +278,8 @@ mod tests { } impl HelloAction { - pub fn new() -> Box { - Box::new(Self::default()) + pub fn new() -> Self { + Self::default() } } @@ -331,12 +311,12 @@ mod tests { ); let node1_id = node1.id(); - graph.add_node(Box::new(node)); - graph.add_node(Box::new(node1)); + graph.add_node(node); + graph.add_node(node1); graph.add_edge(node_id, vec![node1_id]); - graph.run(); + graph.start(); let out = graph.execute_states[&node1_id].get_output().unwrap(); let out: &String = out.get().unwrap(); assert_eq!(out, "Hello world"); diff --git a/src/lib.rs b/src/lib.rs index 7c8c8e4..e2a2051 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub use node::{ node::*, }; +pub use async_trait; pub use graph::graph::*; pub use tokio; pub use utils::{env::EnvVar, output::Output}; diff --git a/src/node/default_node.rs b/src/node/default_node.rs index 929a695..00a5971 100644 --- a/src/node/default_node.rs +++ b/src/node/default_node.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use async_trait::async_trait; + use crate::{ connection::{in_channel::InChannels, out_channel::OutChannels}, utils::{env::EnvVar, output::Output}, @@ -39,7 +41,7 @@ use super::{ /// let mut node_table = NodeTable::new(); /// let mut node = DefaultNode::with_action( /// NodeName::from(node_name), -/// Box::new(EmptyAction), +/// EmptyAction, /// &mut node_table, /// ); /// ``` @@ -50,7 +52,7 @@ pub struct DefaultNode { in_channels: InChannels, out_channels: OutChannels, } - +#[async_trait] impl Node for DefaultNode { fn id(&self) -> NodeId { self.id @@ -68,12 +70,10 @@ impl Node for DefaultNode { &mut self.out_channels } - fn run(&mut self, env: Arc) -> Output { - tokio::runtime::Runtime::new().unwrap().block_on(async { - self.action - .run(&mut self.in_channels, &self.out_channels, env) - .await - }) + async fn run(&mut self, env: Arc) -> Output { + self.action + .run(&mut self.in_channels, &self.out_channels, env) + .await } } @@ -90,17 +90,21 @@ impl DefaultNode { pub fn with_action( name: NodeName, - action: Box, + action: impl Action + 'static, node_table: &mut NodeTable, ) -> Self { Self { id: node_table.alloc_id_for(&name), name, - action, + action: Box::new(action), in_channels: InChannels::default(), out_channels: OutChannels::default(), } } + + pub fn set_action(&mut self, action: impl Action + 'static) { + self.action = Box::new(action) + } } #[cfg(test)] @@ -125,8 +129,8 @@ mod test_default_node { } impl HelloAction { - pub fn new() -> Box { - Box::new(Self::default()) + pub fn new() -> Self { + Self::default() } } @@ -150,7 +154,9 @@ mod test_default_node { assert_eq!(node_table.get(node_name).unwrap(), &node.id()); let env = Arc::new(EnvVar::new(node_table)); - let out = node.run(env).get_out().unwrap(); + let out = tokio::runtime::Runtime::new() + .unwrap() + .block_on(async { node.run(env).await.get_out().unwrap() }); let out: &String = out.get().unwrap(); assert_eq!(out, "Hello world"); } diff --git a/src/node/node.rs b/src/node/node.rs index 9188156..d1b11dc 100644 --- a/src/node/node.rs +++ b/src/node/node.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, sync::Arc}; +use async_trait::async_trait; + use crate::{ connection::{in_channel::InChannels, out_channel::OutChannels}, utils::{env::EnvVar, output::Output}, @@ -15,6 +17,7 @@ use super::id_allocate::alloc_id; /// Nodes can communicate with others asynchronously through [`InChannels`] and [`OutChannels`]. /// /// In addition to the above properties, users can also customize some other attributes. +#[async_trait] pub trait Node: Send + Sync { /// id is the unique identifier of each node, it will be assigned by the [`NodeTable`] /// when creating a new node, you can find this node through this identifier. @@ -26,7 +29,7 @@ pub trait Node: Send + Sync { /// Output Channels of this node. fn output_channels(&mut self) -> &mut OutChannels; /// Execute a run of this node. - fn run(&mut self, env: Arc) -> Output; + async fn run(&mut self, env: Arc) -> Output; } #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] @@ -56,10 +59,12 @@ impl NodeTable { id } + /// Get the [`NodeId`] of the node corresponding to its name. pub fn get(&self, name: &str) -> Option<&NodeId> { self.0.get(name) } + /// Create an empty [`NodeTable`]. pub fn new() -> Self { Self::default() } diff --git a/src/utils/execstate.rs b/src/utils/execstate.rs index 47a94c8..e92c919 100644 --- a/src/utils/execstate.rs +++ b/src/utils/execstate.rs @@ -48,12 +48,6 @@ impl ExecState { self.output.lock().unwrap().clone() } - /// The task execution succeed or not. - /// `true` means no panic occurs. - pub(crate) fn success(&self) -> bool { - self.success.load(Ordering::Relaxed) - } - pub(crate) fn exe_success(&self) { self.success.store(true, Ordering::Relaxed) } diff --git a/src/utils/output.rs b/src/utils/output.rs index 4ac4668..7057a15 100644 --- a/src/utils/output.rs +++ b/src/utils/output.rs @@ -68,7 +68,7 @@ impl Output { } /// Get the contents of [`Output`]. - pub(crate) fn get_out(&self) -> Option { + pub fn get_out(&self) -> Option { match self { Self::Out(ref out) => out.clone(), Self::Err(_) | Self::ErrWithExitCode(_, _) => None, @@ -76,7 +76,7 @@ impl Output { } /// Get error information stored in [`Output`]. - pub(crate) fn get_err(&self) -> Option { + pub fn get_err(&self) -> Option { match self { Self::Out(_) => None, Self::Err(err) => Some(err.to_string()),