Skip to content

Commit

Permalink
replace AlgorithmResult with NodeState in rust
Browse files Browse the repository at this point in the history
  • Loading branch information
ljeub-pometry committed Jan 22, 2025
1 parent c109dda commit facd73f
Show file tree
Hide file tree
Showing 30 changed files with 744 additions and 1,637 deletions.
728 changes: 0 additions & 728 deletions raphtory/src/algorithms/algorithm_result.rs

This file was deleted.

30 changes: 9 additions & 21 deletions raphtory/src/algorithms/centrality/betweenness.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::{
algorithms::algorithm_result::AlgorithmResult,
core::entities::VID,
db::graph::node::NodeView,
db::{api::state::NodeState, graph::node::NodeView},
prelude::{GraphViewOps, NodeViewOps},
};
use ordered_float::OrderedFloat;
use std::collections::{HashMap, VecDeque};

/// Computes the betweenness centrality for nodes in a given graph.
Expand All @@ -17,14 +15,14 @@ use std::collections::{HashMap, VecDeque};
///
/// # Returns
///
/// An [AlgorithmResult] containing the betweenness centrality of each node.
/// A NodeState containing the betweenness centrality of each node.
pub fn betweenness_centrality<'graph, G: GraphViewOps<'graph>>(
g: &'graph G,
k: Option<usize>,
normalized: bool,
) -> AlgorithmResult<G, f64, OrderedFloat<f64>> {
) -> NodeState<'graph, f64, G> {
// Initialize a hashmap to store betweenness centrality values.
let mut betweenness: HashMap<usize, f64> = HashMap::new();
let mut betweenness: Vec<f64> = vec![0.0; g.unfiltered_num_nodes()];

// Get the nodes and the total number of nodes in the graph.
let nodes = g.nodes();
Expand Down Expand Up @@ -86,8 +84,7 @@ pub fn betweenness_centrality<'graph, G: GraphViewOps<'graph>>(
delta.insert(*v, new_delta_v);
}
if w != node.node.0 {
let updated_betweenness = betweenness.entry(w).or_insert(0.0);
*updated_betweenness += delta[&w];
betweenness[w] += delta[&w];
}
}
}
Expand All @@ -96,20 +93,11 @@ pub fn betweenness_centrality<'graph, G: GraphViewOps<'graph>>(
if normalized {
let factor = 1.0 / ((n as f64 - 1.0) * (n as f64 - 2.0));
for node in nodes.iter() {
betweenness
.entry(node.node.0)
.and_modify(|v| *v *= factor)
.or_insert(0.0f64);
}
} else {
for node in nodes.iter() {
betweenness.entry(node.node.0).or_insert(0.0f64);
betweenness[node.node.index()] *= factor;
}
}

// Construct and return the AlgorithmResult
let results_type = std::any::type_name::<f64>();
AlgorithmResult::new(g.clone(), "Betweenness", results_type, betweenness)
NodeState::new_from_eval(g.clone(), betweenness)
}

#[cfg(test)]
Expand Down Expand Up @@ -148,7 +136,7 @@ mod betweenness_centrality_test {
expected.insert("6".to_string(), 0.0);

let res = betweenness_centrality(graph, None, false);
assert_eq!(res.get_all_with_names(), expected);
assert_eq!(res, expected);

let mut expected: HashMap<String, f64> = HashMap::new();
expected.insert("1".to_string(), 0.0);
Expand All @@ -158,7 +146,7 @@ mod betweenness_centrality_test {
expected.insert("5".to_string(), 0.0);
expected.insert("6".to_string(), 0.0);
let res = betweenness_centrality(graph, None, true);
assert_eq!(res.get_all_with_names(), expected);
assert_eq!(res, expected);
});
}
}
79 changes: 22 additions & 57 deletions raphtory/src/algorithms/centrality/degree_centrality.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
use crate::{
algorithms::{algorithm_result::AlgorithmResult, metrics::degree::max_degree},
core::state::{accumulator_id::accumulators::sum, compute_state::ComputeStateVec},
db::{
api::view::StaticGraphViewOps,
task::{
context::Context,
node::eval_node::EvalNodeView,
task::{ATask, Job, Step},
task_runner::TaskRunner,
},
},
db::api::{state::NodeState, view::StaticGraphViewOps},
prelude::*,
};
use ordered_float::OrderedFloat;
use rayon::prelude::*;

/// Computes the degree centrality of all nodes in the graph. The values are normalized
/// by dividing each result with the maximum possible degree. Graphs with self-loops can have
Expand All @@ -21,48 +11,24 @@ use ordered_float::OrderedFloat;
/// # Arguments
///
/// - `g`: A reference to the graph.
/// - `threads` - Number of threads to use
///
/// # Returns
///
/// An [AlgorithmResult] containing the degree centrality of each node.
pub fn degree_centrality<G: StaticGraphViewOps>(
g: &G,
threads: Option<usize>,
) -> AlgorithmResult<G, f64, OrderedFloat<f64>> {
let max_degree = max_degree(g);

let mut ctx: Context<G, ComputeStateVec> = g.into();

let min = sum(0);

ctx.agg(min);
pub fn degree_centrality<G: StaticGraphViewOps>(g: &G) -> NodeState<'static, f64, G> {
let max_degree = match g.nodes().degree().max() {
None => return NodeState::new_empty(g.clone()),
Some(v) => v,
};

let step1 = ATask::new(move |evv: &mut EvalNodeView<_, ()>| {
// The division below is fine as floating point division of 0.0
// causes the result to be an NaN
let res = evv.degree() as f64 / max_degree as f64;
if res.is_nan() || res.is_infinite() {
evv.global_update(&min, 0.0);
} else {
evv.update(&min, res);
}
Step::Done
});
let values: Vec<_> = g
.nodes()
.degree()
.into_par_iter_values()
.map(|v| (v as f64) / max_degree as f64)
.collect();

let mut runner: TaskRunner<G, _> = TaskRunner::new(ctx);
let runner_result = runner.run(
vec![],
vec![Job::new(step1)],
None,
|_, ess, _, _| ess.finalize(&min, |min| min),
threads,
1,
None,
None,
);
let results_type = std::any::type_name::<f64>();
AlgorithmResult::new(g.clone(), "Degree Centrality", results_type, runner_result)
NodeState::new_from_values(g.clone(), values)
}

#[cfg(test)]
Expand All @@ -83,15 +49,14 @@ mod degree_centrality_test {
graph.add_edge(0, *src, *dst, NO_PROPS, None).unwrap();
}
test_storage!(&graph, |graph| {
let mut hash_map_result: HashMap<String, f64> = HashMap::new();
hash_map_result.insert("1".to_string(), 1.0);
hash_map_result.insert("2".to_string(), 1.0);
hash_map_result.insert("3".to_string(), 2.0 / 3.0);
hash_map_result.insert("4".to_string(), 2.0 / 3.0);

let binding = degree_centrality(graph, None);
let res = binding.get_all_with_names();
assert_eq!(res, hash_map_result);
let mut expected: HashMap<String, f64> = HashMap::new();
expected.insert("1".to_string(), 1.0);
expected.insert("2".to_string(), 1.0);
expected.insert("3".to_string(), 2.0 / 3.0);
expected.insert("4".to_string(), 2.0 / 3.0);

let res = degree_centrality(graph);
assert_eq!(res, expected);
});
}
}
58 changes: 14 additions & 44 deletions raphtory/src/algorithms/centrality/hits.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use std::collections::HashMap;

use num_traits::abs;
use ordered_float::OrderedFloat;

use crate::{
algorithms::algorithm_result::AlgorithmResult,
core::{
entities::VID,
state::{
accumulator_id::accumulators::{max, sum},
compute_state::ComputeStateVec,
},
core::state::{
accumulator_id::accumulators::{max, sum},
compute_state::ComputeStateVec,
},
db::{
api::view::{NodeViewOps, StaticGraphViewOps},
api::{
state::NodeState,
view::{NodeViewOps, StaticGraphViewOps},
},
task::{
context::Context,
node::eval_node::EvalNodeView,
Expand All @@ -22,6 +16,7 @@ use crate::{
},
},
};
use num_traits::abs;

#[derive(Debug, Clone)]
struct Hits {
Expand Down Expand Up @@ -58,7 +53,7 @@ pub fn hits<G: StaticGraphViewOps>(
g: &G,
iter_count: usize,
threads: Option<usize>,
) -> AlgorithmResult<G, (f32, f32), (OrderedFloat<f32>, OrderedFloat<f32>)> {
) -> NodeState<(f32, f32), G> {
let mut ctx: Context<G, ComputeStateVec> = g.into();

let recv_hub_score = sum::<f32>(2);
Expand Down Expand Up @@ -142,44 +137,18 @@ pub fn hits<G: StaticGraphViewOps>(

let mut runner: TaskRunner<G, _> = TaskRunner::new(ctx);

let (hub_scores, auth_scores) = runner.run(
runner.run(
vec![],
vec![Job::new(step2), Job::new(step3), Job::new(step4), step5],
None,
|_, _, _, local| {
let mut hubs = HashMap::new();
let mut auths = HashMap::new();
let nodes = g.nodes();
for node in nodes {
let v_gid = node.name();
let VID(v_id) = node.node;
let hit = &local[v_id];
hubs.insert(v_gid.clone(), hit.hub_score);
auths.insert(v_gid, hit.auth_score);
}
(hubs, auths)
NodeState::new_from_eval_mapped(g.clone(), local, |h| (h.hub_score, h.auth_score))
},
threads,
iter_count,
None,
None,
);

let mut results: HashMap<usize, (f32, f32)> = HashMap::new();

hub_scores.into_iter().for_each(|(k, v)| {
results.insert(g.node(k).unwrap().node.0, (v, 0.0));
});

auth_scores.into_iter().for_each(|(k, v)| {
let vid = g.node(k).unwrap().node.0;
let (a, _) = results.get(&vid).unwrap();
results.insert(vid, (*a, v));
});

let results_type = std::any::type_name::<(f32, f32)>();

AlgorithmResult::new(g.clone(), "Hits", results_type, results)
)
}

#[cfg(test)]
Expand All @@ -189,6 +158,7 @@ mod hits_tests {
prelude::NO_PROPS,
test_storage,
};
use std::collections::HashMap;

use super::*;

Expand Down Expand Up @@ -221,7 +191,7 @@ mod hits_tests {
(8, 1),
]);
test_storage!(&graph, |graph| {
let results = hits(graph, 20, None).get_all_with_names();
let results = hits(graph, 20, None);

assert_eq!(
results,
Expand Down
Loading

0 comments on commit facd73f

Please sign in to comment.