Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser: auto_graph macro #81

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ log = "0.4"
env_logger = "0.10.1"
async-trait = "0.1.83"
derive = { path = "derive", optional = true }
proc-macro2 = "1.0"

[dev-dependencies]
simplelog = "0.12"
Expand All @@ -31,4 +32,4 @@ derive = ["derive/derive"]

[[example]]
name = "auto_node"
required-features = ["derive"]
required-features = ["derive"]
14 changes: 13 additions & 1 deletion derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use proc_macro::TokenStream;

#[cfg(feature = "derive")]
mod auto_node;
mod relay;

/// [`auto_node`] is a macro that may be used when customizing nodes. It can only be
/// marked on named struct or unit struct.
Expand Down Expand Up @@ -38,3 +38,15 @@ pub fn auto_node(args: TokenStream, input: TokenStream) -> TokenStream {
use crate::auto_node::auto_node;
auto_node(args, input).into()
}

/// The [`dependencies!`] macro lets users declare every task dependency in a
/// readable `task -> successor ...` form. It expands to a block expression
/// that builds and returns the graph described by those dependencies.
#[cfg(feature = "derive")]
#[proc_macro]
pub fn dependencies(input: TokenStream) -> TokenStream {
    use relay::{add_relay, Relaies};

    // Parse the `task -> successors` list, then emit the graph-building code.
    let parsed = syn::parse_macro_input!(input as Relaies);
    add_relay(parsed).into()
}
106 changes: 106 additions & 0 deletions derive/src/relay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::collections::{HashMap, HashSet};

Check warning on line 1 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused import: `HashMap`

Check warning on line 1 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unused import: `HashMap`

use proc_macro2::Ident;
use syn::{parse::Parse, Token};

/// Parses and processes a set of relay tasks and their successors, and generates a directed graph.
///
/// Step 1: Define the `Relay` struct with a task and its associated successors (other tasks that depend on it).
///
/// Step 2: Implement the `Parse` trait for `Relaies` to parse a sequence of task-successor pairs from input. This creates a vector of `Relay` objects.
///
/// Step 3: In `add_relay`, initialize a directed graph structure using `Graph` and a hash map to store edges between nodes.
///
/// Step 4: Iterate through each `Relay` and update the graph's edge list by adding nodes (tasks) and defining edges between tasks and their successors.
///
/// Step 5: Ensure that each task is only added once to the graph using a cache (`HashSet`) to avoid duplicates.
///
/// Step 6: Populate the edges of the graph with the previously processed data and return the graph.
///
/// This code provides the logic to dynamically build a graph based on parsed task relationships, where each task is a node and the successors define directed edges between nodes.

/// One parsed dependency declaration of the form `task -> successor ...`.
pub(crate) struct Relay {
    // The task on the left-hand side of `->`.
    pub(crate) task: Ident,
    // The tasks on the right-hand side of `->`; each becomes a directed
    // edge `task -> successor` in the generated graph.
    pub(crate) successors: Vec<Ident>,
}

pub(crate) struct Relaies(pub(crate) Vec<Relay>);

impl Parse for Relaies {
    /// Parses a comma-separated sequence of `task -> successor ...` entries.
    /// At least one entry is required: an empty input yields a parse error.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut relays = Vec::new();
        loop {
            // Left-hand side: the task identifier, then the `->` token.
            let task: Ident = input.parse()?;
            input.parse::<syn::Token!(->)>()?;
            // Right-hand side: whitespace-separated successor identifiers,
            // terminated by a comma or the end of input.
            let mut successors: Vec<Ident> = Vec::new();
            while !(input.peek(Token!(,)) || input.is_empty()) {
                successors.push(input.parse()?);
            }
            relays.push(Relay { task, successors });
            // Consume an optional (possibly trailing) comma separator.
            let _ = input.parse::<Token!(,)>();
            if input.is_empty() {
                break;
            }
        }
        Ok(Self(relays))
    }
}

pub(crate) fn add_relay(relaies: Relaies) -> proc_macro2::TokenStream {
let mut token = proc_macro2::TokenStream::new();
let mut cache: HashSet<Ident> = HashSet::new();
token.extend(quote::quote!(
use dagrs::Graph;
use dagrs::NodeId;
use std::collections::HashMap;
use std::collections::HashSet;
let mut edge: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
let mut graph = Graph::new();
));
for relay in relaies.0.iter() {
let task = relay.task.clone();
token.extend(quote::quote!(
let task_id = #task.id();
if(!edge.contains_key(&task_id)){
edge.insert(task_id, HashSet::new());
}
));
for successor in relay.successors.iter() {
token.extend(quote::quote!(
let successor_id = #successor.id();
edge.entry(task_id)
.or_insert_with(HashSet::new)
.insert(successor_id);
));
}
}
for relay in relaies.0.iter() {
let task = relay.task.clone();
if (!cache.contains(&task)) {

Check warning on line 80 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unnecessary parentheses around `if` condition

Check warning on line 80 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unnecessary parentheses around `if` condition
token.extend(quote::quote!(
graph.add_node(Box::new(#task));
));
cache.insert(task);
}
for successor in relay.successors.iter() {
if (!cache.contains(successor)) {

Check warning on line 87 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unnecessary parentheses around `if` condition

Check warning on line 87 in derive/src/relay.rs

View workflow job for this annotation

GitHub Actions / Check

unnecessary parentheses around `if` condition
token.extend(quote::quote!(
graph.add_node(Box::new(#successor));
));
cache.insert(successor.clone());
}
}
}
token.extend(quote::quote!(for (key, value) in &edge {
let vec = value.iter().cloned().collect();
graph.add_edge(key.clone(), vec);
}));

quote::quote!(
{
#token;
graph
}
)
}
45 changes: 45 additions & 0 deletions examples/auto_relay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::sync::Arc;

Check warning on line 1 in examples/auto_relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused import: `std::sync::Arc`

use dagrs::{
auto_node, dependencies,
graph::{self, graph::Graph},

Check warning on line 5 in examples/auto_relay.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused imports: `EnvVar`, `graph::Graph`, and `self`
EmptyAction, EnvVar, InChannels, Node, NodeTable, OutChannels,
};

// A node type expanded by the `auto_node` attribute macro; judging by the
// construction in `main`, the macro injects the `id`, `name`, channel and
// `action` fields used below.
#[auto_node]
struct MyNode {/*Put customized fields here.*/}

fn main() {
    let mut node_table = NodeTable::default();

    let node_name = "auto_node".to_string();

    // All three nodes share the same shape; build them through one closure
    // so each gets a fresh id from the node table.
    let mut make_node = || MyNode {
        id: node_table.alloc_id_for(&node_name),
        name: node_name.clone(),
        input_channels: InChannels::default(),
        output_channels: OutChannels::default(),
        action: Box::new(EmptyAction),
    };

    let s = make_node();
    let a = make_node();
    let b = make_node();

    // Declare the dependencies: `s` feeds both `a` and `b`; `b` feeds `a`.
    let mut g = dependencies!(s -> a b,
        b -> a
    );

    g.run();
}
124 changes: 110 additions & 14 deletions src/graph/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::hash::Hash;
use std::sync::mpsc::channel;

Check warning on line 2 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Unit test & Doc test

unused import: `std::sync::mpsc::channel`

Check warning on line 2 in src/graph/graph.rs

View workflow job for this annotation

GitHub Actions / Check

unused import: `std::sync::mpsc::channel`
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
panic::{self, AssertUnwindSafe},
sync::{atomic::AtomicBool, Arc},
};
Expand All @@ -8,6 +10,7 @@
connection::{in_channel::InChannel, information_packet::Content, out_channel::OutChannel},
node::node::{Node, NodeId, NodeTable},
utils::{env::EnvVar, execstate::ExecState},
Output,
};

use log::{debug, error};
Expand Down Expand Up @@ -46,6 +49,8 @@
/// Mark whether the net task can continue to execute.
/// When an error occurs during the execution of any task, This flag will still be set to true
is_active: Arc<AtomicBool>,
/// Node's in_degree, used for check loop
in_degree: HashMap<NodeId, usize>,
}

impl Graph {
Expand All @@ -57,6 +62,7 @@
execute_states: HashMap::new(),
env: Arc::new(EnvVar::new(NodeTable::default())),
is_active: Arc::new(AtomicBool::new(true)),
in_degree: HashMap::new(),
}
}

Expand All @@ -70,22 +76,29 @@
/// Adds a new node to the `Graph`
pub fn add_node(&mut self, node: Box<dyn Node>) {
self.node_count = self.node_count + 1;
self.nodes.insert(node.id(), node);
let id = node.id();
self.nodes.insert(id, node);
self.in_degree.insert(id, 0);
}
/// Adds an edge between two nodes in the `Graph`.
/// If the outgoing port of the sending node is empty and the number of receiving nodes is > 1, use the broadcast channel
/// An MPSC channel is used if the outgoing port of the sending node is empty and the number of receiving nodes is equal to 1
/// If the outgoing port of the sending node is not empty, adding any number of receiving nodes will change all relevant channels to broadcast
pub fn add_edge(&mut self, from_id: NodeId, to_ids: Vec<NodeId>) {
pub fn add_edge(&mut self, from_id: NodeId, all_to_ids: Vec<NodeId>) {
let from_node = self.nodes.get_mut(&from_id).unwrap();
let from_channel = from_node.output_channels();
let to_ids = Self::remove_duplicates(all_to_ids);
if from_channel.0.is_empty() {
if to_ids.len() > 1 {
let (bcst_sender, _) = broadcast::channel::<Content>(32);
{
for to_id in &to_ids {
from_channel
.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
}
for to_id in &to_ids {
Expand All @@ -99,27 +112,42 @@
let (tx, rx) = mpsc::channel::<Content>(32);
{
from_channel.insert(*to_id, Arc::new(OutChannel::Mpsc(tx.clone())));
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Mpsc(rx))));
}
}
} else {
let (bcst_sender, _) = broadcast::channel::<Content>(32);
if to_ids.len() > 1
|| (to_ids.len() == 1 && !from_channel.0.contains_key(to_ids.get(0).unwrap()))
{
for _channel in from_channel.0.values_mut() {
*_channel = Arc::new(OutChannel::Bcst(bcst_sender.clone()));
let (bcst_sender, _) = broadcast::channel::<Content>(32);
{
for _channel in from_channel.0.values_mut() {
*_channel = Arc::new(OutChannel::Bcst(bcst_sender.clone()));
}
for to_id in &to_ids {
if !from_channel.0.contains_key(to_id) {
self.in_degree
.entry(*to_id)
.and_modify(|e| *e += 1)
.or_insert(0);
}
from_channel
.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
}
}
for to_id in &to_ids {
from_channel.insert(*to_id, Arc::new(OutChannel::Bcst(bcst_sender.clone())));
}
}
for to_id in &to_ids {
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
let receiver = bcst_sender.subscribe();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver))));
if let Some(to_node) = self.nodes.get_mut(to_id) {
let to_channel = to_node.input_channels();
let receiver = bcst_sender.subscribe();
to_channel.insert(from_id, Arc::new(Mutex::new(InChannel::Bcst(receiver))));
}
}
}
}
Expand All @@ -136,6 +164,10 @@
/// This function is used for the execution of a single net.
pub fn run(&mut self) {
self.init();
let is_loop = self.check_loop();
if is_loop {
panic!("Graph contains a loop.");
}
if !self.is_active.load(std::sync::atomic::Ordering::Relaxed) {
eprintln!("Graph is not active. Aborting execution.");
return;
Expand Down Expand Up @@ -179,6 +211,70 @@
self.is_active
.store(false, std::sync::atomic::Ordering::Relaxed);
}

/// Checks whether the graph contains a cycle using Kahn's algorithm:
/// repeatedly remove nodes whose in-degree is zero and decrement the
/// in-degree of their successors.
///
/// Returns `true` when a cycle exists, i.e. when some node could never be
/// processed because it sits on (or downstream of) a cycle.
pub fn check_loop(&mut self) -> bool {
    // Seed the work list with every node that has no incoming edges.
    let mut queue: Vec<NodeId> = self
        .in_degree
        .iter()
        .filter_map(|(&node_id, &degree)| if degree == 0 { Some(node_id) } else { None })
        .collect();

    // Work on a copy so the graph's recorded in-degrees stay intact.
    let mut in_degree = self.in_degree.clone();
    let mut processed_count = 0;

    while let Some(node_id) = queue.pop() {
        processed_count += 1;
        let node = self.nodes.get_mut(&node_id).unwrap();
        let out = node.output_channels();
        // Removing this node lowers the in-degree of each of its successors;
        // only the target ids are needed, not the channels themselves.
        for id in out.0.keys() {
            if let Some(degree) = in_degree.get_mut(id) {
                // saturating_sub guards against usize underflow if the
                // recorded in-degrees ever disagree with the channel map.
                *degree = degree.saturating_sub(1);
                if *degree == 0 {
                    queue.push(*id);
                }
            }
        }
    }
    processed_count < self.node_count
}

/// Collects every task's typed output, keyed by node id. A task whose
/// output is absent (or of a different type) maps to `None`.
pub fn get_results<T: Send + Sync + 'static>(&self) -> HashMap<NodeId, Option<Arc<T>>> {
    self.execute_states
        .iter()
        .map(|(&id, state)| (id, state.get_output().and_then(|content| content.into_inner())))
        .collect()
}
/// Collects every task's raw [`Output`], keyed by node id.
pub fn get_outputs(&self) -> HashMap<NodeId, Output> {
    self.execute_states
        .iter()
        .map(|(&id, state)| (id, state.get_full_output()))
        .collect()
}

/// Before the dag starts executing, set the dag's global environment variable.
/// The environment is wrapped in an `Arc`, so it must be fully configured
/// before calling this; later mutation through `self.env` is not possible.
pub fn set_env(&mut self, env: EnvVar) {
    self.env = Arc::new(env);
}

/// Removes duplicate elements, keeping only the first occurrence of each
/// value and preserving the original order.
fn remove_duplicates<T>(vec: Vec<T>) -> Vec<T>
where
    T: Eq + Hash + Clone,
{
    let mut seen: HashSet<T> = HashSet::with_capacity(vec.len());
    let mut unique = Vec::with_capacity(vec.len());
    for item in vec {
        // `insert` returns true only the first time a value is seen.
        if seen.insert(item.clone()) {
            unique.push(item);
        }
    }
    unique
}
}

#[cfg(test)]
Expand Down
Loading
Loading