Skip to content

Commit

Permalink
feat(hydroflow_plus): add round_robin helpers for networking (#1566)
Browse files Browse the repository at this point in the history
Also fixes compiler crashes when using `.enumerate()` on an un-batched
stream.
  • Loading branch information
shadaj authored Nov 15, 2024
1 parent eb1ad3a commit 22de01f
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 31 deletions.
22 changes: 16 additions & 6 deletions hydroflow_plus/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,10 @@ pub enum HfPlusNode {
},

DeferTick(Box<HfPlusNode>),
Enumerate(Box<HfPlusNode>),
Enumerate {
is_static: bool,
input: Box<HfPlusNode>,
},
Inspect {
f: DebugExpr,
input: Box<HfPlusNode>,
Expand Down Expand Up @@ -500,7 +503,7 @@ impl<'a> HfPlusNode {
HfPlusNode::DeferTick(input) => {
transform(input.as_mut(), seen_tees);
}
HfPlusNode::Enumerate(input) => {
HfPlusNode::Enumerate { input, .. } => {
transform(input.as_mut(), seen_tees);
}
HfPlusNode::Inspect { input, .. } => {
Expand Down Expand Up @@ -987,7 +990,7 @@ impl<'a> HfPlusNode {
(defer_tick_ident, input_location_id)
}

HfPlusNode::Enumerate(input) => {
HfPlusNode::Enumerate { is_static, input } => {
let (input_ident, input_location_id) =
input.emit(graph_builders, built_tees, next_stmt_id);

Expand All @@ -998,9 +1001,16 @@ impl<'a> HfPlusNode {
syn::Ident::new(&format!("stream_{}", enumerate_id), Span::call_site());

let builder = graph_builders.entry(input_location_id).or_default();
builder.add_statement(parse_quote! {
#enumerate_ident = #input_ident -> enumerate();
});

if *is_static {
builder.add_statement(parse_quote! {
#enumerate_ident = #input_ident -> enumerate::<'static>();
});
} else {
builder.add_statement(parse_quote! {
#enumerate_ident = #input_ident -> enumerate::<'tick>();
});
}

(enumerate_ident, input_location_id)
}
Expand Down
74 changes: 70 additions & 4 deletions hydroflow_plus/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,23 @@ impl<'a, T, L: Location<'a>, B> Stream<T, L, B> {
}

pub fn enumerate(self) -> Stream<(usize, T), L, B> {
Stream::new(
self.location,
HfPlusNode::Enumerate(Box::new(self.ir_node.into_inner())),
)
if L::is_top_level() {
Stream::new(
self.location,
HfPlusNode::Persist(Box::new(HfPlusNode::Enumerate {
is_static: true,
input: Box::new(HfPlusNode::Unpersist(Box::new(self.ir_node.into_inner()))),
})),
)
} else {
Stream::new(
self.location,
HfPlusNode::Enumerate {
is_static: false,
input: Box::new(self.ir_node.into_inner()),
},
)
}
}

pub fn unique(self) -> Stream<T, L, B>
Expand Down Expand Up @@ -789,6 +802,21 @@ impl<'a, T, L: Location<'a> + NoTick, B> Stream<T, L, B> {
.send_bincode(other)
}

/// Distributes the elements of this stream across the members of the
/// `other` cluster in round-robin order, sending each element with
/// `send_bincode`.
///
/// Each element is paired with a running index via `enumerate()`, and the
/// index (modulo the number of cluster members) selects the destination
/// `ClusterId` from `other.members()`.
///
/// Returns whatever `send_bincode` yields for this location pair, i.e. a
/// stream of `L::Out<T>` located on the `other` cluster.
///
/// NOTE(review): if `other.members()` is empty at runtime, `i % ids.len()`
/// is a remainder by zero and would panic — confirm whether an empty
/// cluster is possible here.
pub fn round_robin_bincode<C2: 'a>(
self,
other: &Cluster<'a, C2>,
) -> Stream<L::Out<T>, Cluster<'a, C2>, Unbounded>
where
L: CanSend<'a, Cluster<'a, C2>, In<T> = (ClusterId<C2>, T)>,
T: Clone + Serialize + DeserializeOwned,
{
// Member IDs of the destination cluster; captured by the staged closure below.
let ids = other.members();

// Tag every element with its index, then turn the index into a target
// cluster member so consecutive elements go to consecutive members.
self.enumerate()
.map(q!(|(i, w)| (ids[i % ids.len()], w)))
.send_bincode(other)
}

pub fn broadcast_bincode_interleaved<C2: 'a, Tag>(
self,
other: &Cluster<'a, C2>,
Expand All @@ -800,6 +828,17 @@ impl<'a, T, L: Location<'a> + NoTick, B> Stream<T, L, B> {
self.broadcast_bincode(other).map(q!(|(_, b)| b))
}

/// Same as `round_robin_bincode`, but drops the first component of each
/// received `(Tag, T)` pair so the resulting stream carries only the
/// payloads.
///
/// The `Out<T> = (Tag, T)` bound restricts this to location pairs whose
/// receive side produces tagged tuples; the tag is discarded by the final
/// `map`.
pub fn round_robin_bincode_interleaved<C2: 'a, Tag>(
self,
other: &Cluster<'a, C2>,
) -> Stream<T, Cluster<'a, C2>, Unbounded>
where
L: CanSend<'a, Cluster<'a, C2>, In<T> = (ClusterId<C2>, T), Out<T> = (Tag, T)> + 'a,
T: Clone + Serialize + DeserializeOwned,
{
// Keep only the payload; the tag half of each tuple is ignored.
self.round_robin_bincode(other).map(q!(|(_, b)| b))
}

pub fn broadcast_bytes<C2: 'a>(
self,
other: &Cluster<'a, C2>,
Expand All @@ -817,6 +856,21 @@ impl<'a, T, L: Location<'a> + NoTick, B> Stream<T, L, B> {
.send_bytes(other)
}

/// Distributes the elements of this stream across the members of the
/// `other` cluster in round-robin order, sending each element with
/// `send_bytes`.
///
/// Each element is paired with a running index via `enumerate()`, and the
/// index (modulo the number of cluster members) selects the destination
/// `ClusterId` from `other.members()`.
///
/// Returns whatever `send_bytes` yields for this location pair, i.e. a
/// stream of `L::Out<Bytes>` located on the `other` cluster.
///
/// NOTE(review): if `other.members()` is empty at runtime, `i % ids.len()`
/// is a remainder by zero and would panic — confirm whether an empty
/// cluster is possible here.
pub fn round_robin_bytes<C2: 'a>(
self,
other: &Cluster<'a, C2>,
) -> Stream<L::Out<Bytes>, Cluster<'a, C2>, Unbounded>
where
L: CanSend<'a, Cluster<'a, C2>, In<Bytes> = (ClusterId<C2>, T)> + 'a,
T: Clone,
{
// Member IDs of the destination cluster; captured by the staged closure below.
let ids = other.members();

// Tag every element with its index, then turn the index into a target
// cluster member so consecutive elements go to consecutive members.
self.enumerate()
.map(q!(|(i, w)| (ids[i % ids.len()], w)))
.send_bytes(other)
}

pub fn broadcast_bytes_interleaved<C2: 'a, Tag>(
self,
other: &Cluster<'a, C2>,
Expand All @@ -828,6 +882,18 @@ impl<'a, T, L: Location<'a> + NoTick, B> Stream<T, L, B> {
{
self.broadcast_bytes(other).map(q!(|(_, b)| b))
}

/// Same as `round_robin_bytes`, but drops the first component of each
/// received `(Tag, Bytes)` pair so the resulting stream carries only the
/// raw `Bytes` payloads.
///
/// The `Out<Bytes> = (Tag, Bytes)` bound restricts this to location pairs
/// whose receive side produces tagged tuples; the tag is discarded by the
/// final `map`.
pub fn round_robin_bytes_interleaved<C2: 'a, Tag>(
self,
other: &Cluster<'a, C2>,
) -> Stream<Bytes, Cluster<'a, C2>, Unbounded>
where
L: CanSend<'a, Cluster<'a, C2>, In<Bytes> = (ClusterId<C2>, T), Out<Bytes> = (Tag, Bytes)>
+ 'a,
T: Clone,
{
// Keep only the payload; the tag half of each tuple is ignored.
self.round_robin_bytes(other).map(q!(|(_, b)| b))
}
}

#[cfg(test)]
Expand Down
14 changes: 2 additions & 12 deletions hydroflow_plus_test/src/cluster/map_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,8 @@ pub fn map_reduce<'a>(flow: &FlowBuilder<'a>) -> (Process<'a, Leader>, Cluster<'
.source_iter(q!(vec!["abc", "abc", "xyz", "abc"]))
.map(q!(|s| s.to_string()));

let all_ids_vec = cluster.members();
let words_partitioned = words
.tick_batch(&process.tick())
.enumerate()
.map(q!(|(i, w)| (
ClusterId::from_raw((i % all_ids_vec.len()) as u32),
w
)))
.all_ticks();

words_partitioned
.send_bincode(&cluster)
words
.round_robin_bincode(&cluster)
.map(q!(|string| (string, ())))
.tick_batch(&cluster.tick())
.fold_keyed(q!(|| 0), q!(|count, _| *count += 1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ expression: built.ir()
),
),
input: Map {
f: stageleft :: runtime_support :: fn1_type_hint :: < (usize , std :: string :: String) , (hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > , std :: string :: String) > ({ use crate :: __staged :: cluster :: map_reduce :: * ; let all_ids_vec = unsafe { :: std :: mem :: transmute :: < _ , & :: std :: vec :: Vec < hydroflow_plus :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > > > (__hydroflow_plus_cluster_ids_1) } ; | (i , w) | (ClusterId :: from_raw ((i % all_ids_vec . len ()) as u32) , w) }),
input: Enumerate(
Map {
f: stageleft :: runtime_support :: fn1_type_hint :: < (usize , std :: string :: String) , (hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > , std :: string :: String) > ({ use hydroflow_plus :: __staged :: stream :: * ; let ids = unsafe { :: std :: mem :: transmute :: < _ , & :: std :: vec :: Vec < hydroflow_plus :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > > > (__hydroflow_plus_cluster_ids_1) } ; | (i , w) | (ids [i % ids . len ()] , w) }),
input: Enumerate {
is_static: true,
input: Map {
f: stageleft :: runtime_support :: fn1_type_hint :: < & str , std :: string :: String > ({ use crate :: __staged :: cluster :: map_reduce :: * ; | s | s . to_string () }),
input: Source {
source: Iter(
Expand All @@ -91,7 +92,7 @@ expression: built.ir()
),
},
},
),
},
},
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ expression: ir.surface_syntax_string()
---
1v1 = source_iter ({ use crate :: __staged :: cluster :: map_reduce :: * ; vec ! ["abc" , "abc" , "xyz" , "abc"] });
2v1 = map (stageleft :: runtime_support :: fn1_type_hint :: < & str , std :: string :: String > ({ use crate :: __staged :: cluster :: map_reduce :: * ; | s | s . to_string () }));
3v1 = enumerate ();
4v1 = map (stageleft :: runtime_support :: fn1_type_hint :: < (usize , std :: string :: String) , (hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > , std :: string :: String) > ({ use crate :: __staged :: cluster :: map_reduce :: * ; let all_ids_vec = unsafe { :: std :: mem :: transmute :: < _ , & :: std :: vec :: Vec < hydroflow_plus :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > > > (__hydroflow_plus_cluster_ids_1) } ; | (i , w) | (ClusterId :: from_raw ((i % all_ids_vec . len ()) as u32) , w) }));
3v1 = enumerate :: < 'static > ();
4v1 = map (stageleft :: runtime_support :: fn1_type_hint :: < (usize , std :: string :: String) , (hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > , std :: string :: String) > ({ use hydroflow_plus :: __staged :: stream :: * ; let ids = unsafe { :: std :: mem :: transmute :: < _ , & :: std :: vec :: Vec < hydroflow_plus :: ClusterId < hydroflow_plus_test :: cluster :: map_reduce :: Worker > > > (__hydroflow_plus_cluster_ids_1) } ; | (i , w) | (ids [i % ids . len ()] , w) }));
5v1 = map (| (id , data) : (hydroflow_plus :: ClusterId < _ > , std :: string :: String) | { (id . raw_id , hydroflow_plus :: runtime_support :: bincode :: serialize :: < std :: string :: String > (& data) . unwrap () . into ()) });
6v1 = dest_sink ({ use hydroflow_plus :: __staged :: deploy :: deploy_runtime :: * ; let env = FAKE ; let p1_port = "port_0" ; { env . port (p1_port) . connect_local_blocking :: < ConnectedDemux < ConnectedDirect > > () . into_sink () } });
7v1 = source_stream ({ use hydroflow_plus :: __staged :: deploy :: deploy_runtime :: * ; let env = FAKE ; let p2_port = "port_1" ; { env . port (p2_port) . connect_local_blocking :: < ConnectedTagged < ConnectedDirect > > () . into_source () } });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,9 @@ expression: built.ir()
f: stageleft :: runtime_support :: fn1_type_hint :: < (((usize , hydroflow_plus_test :: cluster :: paxos_kv :: KvPayload < u32 , hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: paxos_bench :: Client > >) , usize) , u32) , hydroflow_plus_test :: cluster :: paxos :: P2a < hydroflow_plus_test :: cluster :: paxos_kv :: KvPayload < u32 , hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: paxos_bench :: Client > > > > ({ use crate :: __staged :: cluster :: paxos :: * ; let p_id = hydroflow_plus :: ClusterId :: < hydroflow_plus_test :: cluster :: paxos :: Proposer > :: from_raw (__hydroflow_plus_cluster_self_id_0) ; move | (((index , payload) , next_slot) , ballot_num) | P2a { ballot : Ballot { num : ballot_num , proposer_id : p_id } , slot : next_slot + index , value : Some (payload) } }),
input: CrossSingleton(
CrossSingleton(
Enumerate(
Map {
Enumerate {
is_static: false,
input: Map {
f: stageleft :: runtime_support :: fn1_type_hint :: < (hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: paxos_bench :: Client > , hydroflow_plus_test :: cluster :: paxos_kv :: KvPayload < u32 , hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: paxos_bench :: Client > >) , hydroflow_plus_test :: cluster :: paxos_kv :: KvPayload < u32 , hydroflow_plus :: location :: cluster :: ClusterId < hydroflow_plus_test :: cluster :: paxos_bench :: Client > > > ({ use hydroflow_plus :: __staged :: stream :: * ; | (_ , b) | b }),
input: Network {
from_location: Cluster(
Expand Down Expand Up @@ -601,7 +602,7 @@ expression: built.ir()
},
},
},
),
},
Tee {
inner: <tee 12>,
},
Expand Down

0 comments on commit 22de01f

Please sign in to comment.