diff --git a/calyx-ir/src/control.rs b/calyx-ir/src/control.rs index c436513e8..a64005e09 100644 --- a/calyx-ir/src/control.rs +++ b/calyx-ir/src/control.rs @@ -589,6 +589,17 @@ impl Control { let empty = Control::empty(); std::mem::replace(self, empty) } + + /// Replaces &mut self with an empty control statement, and returns StaticControl + /// of self. Note that this only works on Control that is static + pub fn take_static_control(&mut self) -> StaticControl { + let empty = Control::empty(); + let control = std::mem::replace(self, empty); + let Control::Static(static_control) = control else { + unreachable!("Called take_static_control on non-static control") + }; + static_control + } } impl StaticControl { diff --git a/calyx-opt/src/passes/compile_repeat.rs b/calyx-opt/src/passes/compile_repeat.rs index ffe707ea9..17c35ace4 100644 --- a/calyx-opt/src/passes/compile_repeat.rs +++ b/calyx-opt/src/passes/compile_repeat.rs @@ -66,6 +66,10 @@ impl Visitor for CompileRepeat { ) .to_vec(); init_group.borrow_mut().assignments = init_assigns; + init_group + .borrow_mut() + .attributes + .insert(ir::NumAttr::PromoteStatic, 1); // incr_group: // 1) writes results of idx + 1 into idx (i.e., increments idx) // 2) writes the result of (idx + 1 < num_repeats) into cond_reg, @@ -84,6 +88,10 @@ impl Visitor for CompileRepeat { ) .to_vec(); incr_group.borrow_mut().assignments = idx_incr_assigns; + incr_group + .borrow_mut() + .attributes + .insert(ir::NumAttr::PromoteStatic, 1); // create control: // init_group; while cond_reg.out {repeat_body; incr_group;} let while_body = ir::Control::seq(vec![ diff --git a/calyx-opt/src/passes/static_promotion.rs b/calyx-opt/src/passes/static_promotion.rs index f71a4934d..0236941a1 100644 --- a/calyx-opt/src/passes/static_promotion.rs +++ b/calyx-opt/src/passes/static_promotion.rs @@ -1,15 +1,19 @@ -use crate::analysis::{GraphAnalysis, IntoStatic, ReadWriteSet}; +use crate::analysis::{GraphAnalysis, ReadWriteSet}; use crate::traversal::{ Action, ConstructVisitor, Named, Order, VisResult, Visitor, }; use calyx_ir::{self as ir, LibrarySignatures, RRC}; use calyx_utils::{CalyxResult, Error}; -use ir::{GetAttributes, StaticControl}; +use ir::GetAttributes; use itertools::Itertools; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroU64; use std::rc::Rc; +const APPROX_ENABLE_SIZE: u64 = 1; +const APPROX_IF_SIZE: u64 = 3; +const APPROX_WHILE_REPEAT_SIZE: u64 = 3; + /// Struct to store information about the go-done interfaces defined by a primitive. #[derive(Default, Debug)] struct GoDone { @@ -134,6 +138,12 @@ pub struct StaticPromotion { latency_data: HashMap, /// dynamic group Id -> promoted static group Id static_group_name: HashMap, + /// Maps static component names to their latencies + static_component_latencies: HashMap, + /// Threshold for promotion + threshold: u64, + /// Whether we should stop promoting when we see a loop. + stop_loop: bool, } // Override constructor to build latency_data information from the primitives @@ -167,9 +177,13 @@ impl ConstructVisitor for StaticPromotion { } latency_data.insert(prim.name, GoDone::new(go_ports)); } + let (threshold, stop_loop) = Self::get_threshold(ctx); Ok(StaticPromotion { latency_data, static_group_name: HashMap::new(), + static_component_latencies: HashMap::new(), + threshold, + stop_loop, }) } @@ -190,6 +204,54 @@ impl Named for StaticPromotion { } impl StaticPromotion { + // Looks through ctx to get the given command line threshold. + // Default threshold = 1 + fn get_threshold(ctx: &ir::Context) -> (u64, bool) + where + Self: Named, + { + let n = Self::name(); + + // getting the given opts for -x cell-share:__ + let given_opts: HashSet<_> = ctx + .extra_opts + .iter() + .filter_map(|opt| { + let mut splits = opt.split(':'); + if splits.next() == Some(n) { + splits.next() + } else { + None + } + }) + .collect(); + + let mut stop_loop = false; + given_opts.iter().for_each(|arg| { + if *arg == "stop_loop" { + stop_loop = true + } + }); + + // searching for "-x static-promotion:threshold=1" and getting back "1" + let threshold: Option<&str> = given_opts.iter().find_map(|arg| { + let split: Vec<&str> = arg.split('=').collect(); + if let Some(str) = split.first() { + if str == &"threshold" { + return Some(split[1]); + } + } + None + }); + + // Need to convert string argument into int argument + // Always default to threshold=1 + ( + threshold.unwrap_or("1").parse::().unwrap_or(1), + stop_loop, + ) + } + /// Return true if the edge (`src`, `dst`) meet one these criteria, and false otherwise: /// - `src` is an "out" port of a constant, and `dst` is a "go" port /// - `src` is a "done" port, and `dst` is a "go" port @@ -417,30 +479,45 @@ impl StaticPromotion { Some(latency_sum) } - /// if we've already constructed the static group for the `ir::Enable`, then - /// just promote `Enable` to `StaticEnable` using the constructed `static group` - /// if not, then first construct `static group` and then promote `Enable` to - /// `StaticEnable` - fn construct_static_enable( + /// Gets the inferred latency, which should either be from being a static + /// control operator or the promote_static attribute. + /// Will raise an error if neither of these is true. + fn get_inferred_latency(c: &ir::Control) -> u64 { + let ir::Control::Static(sc) = c else { + let Some(latency) = c.get_attribute(ir::NumAttr::PromoteStatic) else { + unreachable!("Called get_latency on control that is neither static nor promotable") + }; + return latency; + }; + sc.get_latency() + } + + fn check_latencies_match(actual: u64, inferred: u64) { + assert_eq!(actual, inferred, "Inferred and Annotated Latencies do not match. Latency: {}. Inferred: {}", actual, inferred); + } + + /// Returns true if a control statement is already static, or has the static + /// attributes + fn can_be_promoted(c: &ir::Control) -> bool { + c.is_static() || c.has_attribute(ir::NumAttr::PromoteStatic) + } + + /// If we've already constructed the static group then use the already existing + /// group. Otherwise construct `static group` and then return that. + fn construct_static_group( &mut self, builder: &mut ir::Builder, - en: ir::Enable, - ) -> ir::StaticControl { - if let Some(s_name) = - self.static_group_name.get(&en.group.borrow().name()) + group: ir::RRC, + latency: u64, + ) -> ir::RRC { + if let Some(s_name) = self.static_group_name.get(&group.borrow().name()) { - ir::StaticControl::Enable(ir::StaticEnable { - group: builder.component.find_static_group(*s_name).unwrap(), - attributes: ir::Attributes::default(), - }) + builder.component.find_static_group(*s_name).unwrap() } else { - let sg = builder.add_static_group( - en.group.borrow().name(), - en.get_attributes().get(ir::NumAttr::PromoteStatic).unwrap(), - ); + let sg = builder.add_static_group(group.borrow().name(), latency); self.static_group_name - .insert(en.group.borrow().name(), sg.borrow().name()); - for assignment in en.group.borrow().assignments.iter() { + .insert(group.borrow().name(), sg.borrow().name()); + for assignment in group.borrow().assignments.iter() { if !(assignment.dst.borrow().is_hole() && assignment.dst.borrow().name == "done") { @@ -448,69 +525,271 @@ impl StaticPromotion { sg.borrow_mut().assignments.push(static_s); } } - ir::StaticControl::Enable(ir::StaticEnable { - group: Rc::clone(&sg), - attributes: ir::Attributes::default(), - }) + Rc::clone(&sg) } } - fn construct_static_seq( + /// Converts control to static control. + /// Control must already be static or have the `promote_static` attribute. + fn convert_to_static( &mut self, + c: &mut ir::Control, builder: &mut ir::Builder, - static_vec: &mut Vec, - ) -> ir::Control { - let mut latency = 0; - let mut static_seq_st: Vec = Vec::new(); - for s in std::mem::take(static_vec) { - match s { - ir::Control::Static(sc) => { - latency += sc.get_latency(); - static_seq_st.push(sc); - } - ir::Control::Enable(en) => { - let sen = self.construct_static_enable(builder, en); - latency += sen.get_latency(); - static_seq_st.push(sen); - } - _ => unreachable!("We do not insert non-static controls other than group enables with `promote_static` attribute") - } + ) -> ir::StaticControl { + assert!( + c.has_attribute(ir::NumAttr::PromoteStatic) || c.is_static(), + "Called convert_to_static control that is neither static nor promotable" + ); + // Need to get bound_attribute here, because we cannot borrow `c` within the + // pattern match. + let bound_attribute = c.get_attribute(ir::NumAttr::Bound); + // Inferred latency of entire control block. Used to double check our + // function is correct. + let inferred_latency = Self::get_inferred_latency(c); + match c { + ir::Control::Empty(_) => ir::StaticControl::empty(), + ir::Control::Enable(ir::Enable { group, attributes }) => { + // Removing the `promote_static` attribute bc we don't need it anymore. + attributes.remove(ir::NumAttr::PromoteStatic); + let enable = ir::StaticControl::Enable(ir::StaticEnable { + // upgrading group to static group + group: self.construct_static_group( + builder, + Rc::clone(group), + group + .borrow() + .get_attributes() + .unwrap() + .get(ir::NumAttr::PromoteStatic) + .unwrap(), + ), + attributes: std::mem::take(attributes), + }); + enable + } + ir::Control::Seq(ir::Seq { stmts, attributes }) => { + // Removing the `promote_static` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::PromoteStatic); + // The resulting static seq should be compactable. + attributes.insert(ir::NumAttr::Compactable, 1); + let static_stmts = + self.convert_vec_to_static(builder, std::mem::take(stmts)); + let latency = + static_stmts.iter().map(|s| s.get_latency()).sum(); + Self::check_latencies_match(latency, inferred_latency); + ir::StaticControl::Seq(ir::StaticSeq { + stmts: static_stmts, + attributes: std::mem::take(attributes), + latency, + }) + } + ir::Control::Par(ir::Par { stmts, attributes }) => { + // Removing the `promote_static` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::PromoteStatic); + // Convert stmts to static + let static_stmts = + self.convert_vec_to_static(builder, std::mem::take(stmts)); + // Calculate latency + let latency = static_stmts + .iter() + .map(|s| s.get_latency()) + .max() + .unwrap_or_else(|| unreachable!("Empty Par Block")); + Self::check_latencies_match(latency, inferred_latency); + ir::StaticControl::Par(ir::StaticPar { + stmts: static_stmts, + attributes: ir::Attributes::default(), + latency, + }) + } + ir::Control::Repeat(ir::Repeat { + body, + num_repeats, + attributes, + }) => { + // Removing the `promote_static` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::PromoteStatic); + let sc = self.convert_to_static(body, builder); + let latency = (*num_repeats) * sc.get_latency(); + Self::check_latencies_match(latency, inferred_latency); + ir::StaticControl::Repeat(ir::StaticRepeat { + attributes: std::mem::take(attributes), + body: Box::new(sc), + num_repeats: *num_repeats, + latency, + }) + } + ir::Control::While(ir::While { + body, attributes, .. + }) => { + // Removing the `promote_static` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::PromoteStatic); + // Removing the `bound` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::Bound); + let sc = self.convert_to_static(body, builder); + let num_repeats = bound_attribute.unwrap_or_else(|| unreachable!("Called convert_to_static on a while loop without a bound")); + let latency = num_repeats * sc.get_latency(); + Self::check_latencies_match(latency, inferred_latency); + ir::StaticControl::Repeat(ir::StaticRepeat { + attributes: std::mem::take(attributes), + body: Box::new(sc), + num_repeats, + latency, + }) + } + ir::Control::If(ir::If { + port, + tbranch, + fbranch, + attributes, + .. + }) => { + // Removing the `promote_static` attribute bc we don't need it anymore + attributes.remove(ir::NumAttr::PromoteStatic); + let static_tbranch = self.convert_to_static(tbranch, builder); + let static_fbranch = self.convert_to_static(fbranch, builder); + let latency = std::cmp::max( + static_tbranch.get_latency(), + static_fbranch.get_latency(), + ); + Self::check_latencies_match(latency, inferred_latency); + ir::StaticControl::static_if( + Rc::clone(port), + Box::new(static_tbranch), + Box::new(static_fbranch), + latency, + ) + } + ir::Control::Static(_) => c.take_static_control(), + ir::Control::Invoke(ir::Invoke { + comp, + inputs, + outputs, + attributes, + comb_group, + ref_cells, + }) => { + assert!( + comb_group.is_none(), + "Shouldn't Promote to Static if there is a Comb Group", + ); + attributes.remove(ir::NumAttr::PromoteStatic); + Self::check_latencies_match(self.static_component_latencies.get( + &comp.borrow().type_name().unwrap_or_else(|| { + unreachable!( + "Already checked that comp is component" + ) + }), + ).unwrap_or_else(|| unreachable!("Called convert_to_static for static invoke that does not have a static component")).get(), inferred_latency); + let s_inv = ir::StaticInvoke { + comp: Rc::clone(comp), + inputs: std::mem::take(inputs), + outputs: std::mem::take(outputs), + latency: inferred_latency, + attributes: std::mem::take(attributes), + ref_cells: std::mem::take(ref_cells), + }; + ir::StaticControl::Invoke(s_inv) + } } - let mut attributes = ir::Attributes::default(); - attributes.insert(ir::NumAttr::Compactable, 1); - ir::Control::Static(StaticControl::Seq(ir::StaticSeq { - stmts: static_seq_st, - attributes, - latency, - })) } - fn construct_static_par( + /// Converts vec of control to vec of static control. + /// All control statements in the vec must be promotable or already static. + fn convert_vec_to_static( &mut self, builder: &mut ir::Builder, - s_stmts: &mut Vec, - ) -> ir::Control { - let mut latency = 0; - let mut static_par_st: Vec = Vec::new(); - for s in std::mem::take(s_stmts) { - match s { - ir::Control::Static(sc) => { - latency = std::cmp::max(latency, sc.get_latency()); - static_par_st.push(sc); - } - ir::Control::Enable(en) => { - let sen = self.construct_static_enable(builder, en); - latency = std::cmp::max(latency, sen.get_latency()); - static_par_st.push(sen); - } - _ => unreachable!("We do not insert non-static controls other than group enables with `promote_static` attribute") + control_vec: Vec, + ) -> Vec { + control_vec + .into_iter() + .map(|mut c| self.convert_to_static(&mut c, builder)) + .collect() + } + + /// Calculates the approximate "size" of the control statements. + /// Tries to approximate the number of dynamic FSM transitions that will occur + fn approx_size(c: &ir::Control) -> u64 { + match c { + ir::Control::Empty(_) => 0, + ir::Control::Enable(_) => APPROX_ENABLE_SIZE, + ir::Control::Seq(ir::Seq { stmts, .. }) + | ir::Control::Par(ir::Par { stmts, .. }) => { + stmts.iter().map(Self::approx_size).sum() + } + ir::Control::Repeat(ir::Repeat { body, .. }) + | ir::Control::While(ir::While { body, .. }) => { + Self::approx_size(body) + APPROX_WHILE_REPEAT_SIZE + } + ir::Control::If(ir::If { + tbranch, fbranch, .. + }) => { + Self::approx_size(tbranch) + + Self::approx_size(fbranch) + + APPROX_IF_SIZE + } + ir::Control::Static(_) => { + // static control appears as one big group to the dynamic FSM + 1 } + ir::Control::Invoke(_) => 1, } - ir::Control::Static(StaticControl::Par(ir::StaticPar { - stmts: static_par_st, - attributes: ir::Attributes::default(), - latency, - })) + } + + /// Uses `approx_size` function to sum the sizes of the control statements + /// in the given vector + fn approx_control_vec_size(v: &[ir::Control]) -> u64 { + v.iter().map(Self::approx_size).sum() + } + + /// First checks if the vec of control statements meets the self.threshold. + /// (That is, whether the combined approx_size of the static_vec is greater) + /// Than the threshold. + /// If so, converts vec of control to a static seq, and returns a vec containing + /// the static seq. + /// Otherwise, just returns the vec without changing it. + fn convert_vec_seq_if_threshold( + &mut self, + builder: &mut ir::Builder, + control_vec: Vec, + ) -> Vec { + if Self::approx_control_vec_size(&control_vec) <= self.threshold { + // Return unchanged vec + return control_vec; + } + // Convert vec to static seq + let s_seq_stmts = self.convert_vec_to_static(builder, control_vec); + let latency = s_seq_stmts.iter().map(|sc| sc.get_latency()).sum(); + let mut sseq = + ir::Control::Static(ir::StaticControl::seq(s_seq_stmts, latency)); + sseq.get_mut_attributes() + .insert(ir::NumAttr::Compactable, 1); + vec![sseq] + } + + /// First checks if the vec of control statements meets the self.threshold. + /// If so, converts vec of control to a static par, and returns a vec containing + /// the static par. + /// Otherwise, just returns the vec without changing it. + fn convert_vec_par_if_threshold( + &mut self, + builder: &mut ir::Builder, + control_vec: Vec, + ) -> Vec { + if Self::approx_control_vec_size(&control_vec) <= self.threshold { + // Return unchanged vec + return control_vec; + } + // Convert vec to static seq + let s_par_stmts = self.convert_vec_to_static(builder, control_vec); + let latency = s_par_stmts + .iter() + .map(|sc| sc.get_latency()) + .max() + .unwrap_or_else(|| unreachable!("empty par block")); + let spar = + ir::Control::Static(ir::StaticControl::par(s_par_stmts, latency)); + vec![spar] } } @@ -545,6 +824,10 @@ impl Visitor for StaticPromotion { } } } + if comp.is_static() { + self.static_component_latencies + .insert(comp.name, comp.latency.unwrap()); + } Ok(Action::Continue) } @@ -614,33 +897,16 @@ impl Visitor for StaticPromotion { s: &mut ir::Invoke, _comp: &mut ir::Component, _sigs: &LibrarySignatures, - comps: &[ir::Component], + _comps: &[ir::Component], ) -> VisResult { - if s.comp.borrow().is_component() { - let name = s.comp.borrow().type_name().unwrap(); - for c in comps { - if c.name == name && c.is_static() { - let emp = ir::Invoke { - comp: Rc::clone(&s.comp), - inputs: Vec::new(), - outputs: Vec::new(), - attributes: ir::Attributes::default(), - comb_group: None, - ref_cells: Vec::new(), - }; - let actual_invoke = std::mem::replace(s, emp); - let s_inv = ir::StaticInvoke { - comp: Rc::clone(&actual_invoke.comp), - inputs: actual_invoke.inputs, - outputs: actual_invoke.outputs, - latency: c.latency.unwrap().get(), - attributes: ir::Attributes::default(), - ref_cells: actual_invoke.ref_cells, - }; - return Ok(Action::change(ir::Control::Static( - ir::StaticControl::Invoke(s_inv), - ))); - } + // Shouldn't promote to static invoke if we have a comb group + if s.comp.borrow().is_component() && s.comb_group.is_none() { + if let Some(latency) = self + .static_component_latencies + .get(&s.comp.borrow().type_name().unwrap()) + { + s.attributes + .insert(ir::NumAttr::PromoteStatic, latency.get()); } } Ok(Action::Continue) @@ -654,45 +920,55 @@ impl Visitor for StaticPromotion { _comps: &[ir::Component], ) -> VisResult { let mut builder = ir::Builder::new(comp, sigs); + let old_stmts = std::mem::take(&mut s.stmts); let mut new_stmts: Vec = Vec::new(); - let mut static_vec: Vec = Vec::new(); - for stmt in std::mem::take(&mut s.stmts) { - if stmt.is_static() - || stmt.has_attribute(ir::NumAttr::PromoteStatic) - { - static_vec.push(stmt); + let mut cur_vec: Vec = Vec::new(); + for stmt in old_stmts { + if Self::can_be_promoted(&stmt) { + cur_vec.push(stmt); } else { - // we do not promote if static_vec contains only 1 single enable - match static_vec.len().cmp(&1) { - std::cmp::Ordering::Equal => { - new_stmts.extend(static_vec); - } - std::cmp::Ordering::Greater => { - let sseq = self.construct_static_seq( - &mut builder, - &mut static_vec, - ); - new_stmts.push(sseq); - } - _ => {} - } + // Accumualte cur_vec into a static seq if it meets threshold + let possibly_promoted_stmts = + self.convert_vec_seq_if_threshold(&mut builder, cur_vec); + new_stmts.extend(possibly_promoted_stmts); + cur_vec = Vec::new(); + // Add the current (non-promotable) stmt new_stmts.push(stmt); - static_vec = Vec::new(); } } - if !static_vec.is_empty() { - if static_vec.len() == 1 { - new_stmts.extend(static_vec); + if new_stmts.is_empty() { + // The entire seq can be promoted + let approx_size: u64 = cur_vec.iter().map(Self::approx_size).sum(); + if approx_size > self.threshold { + // Promote entire seq to a static seq + let s_seq_stmts = + self.convert_vec_to_static(&mut builder, cur_vec); + let latency = + s_seq_stmts.iter().map(|sc| sc.get_latency()).sum(); + let mut sseq = ir::Control::Static(ir::StaticControl::seq( + s_seq_stmts, + latency, + )); + sseq.get_mut_attributes() + .insert(ir::NumAttr::Compactable, 1); + return Ok(Action::change(sseq)); } else { - let sseq = - self.construct_static_seq(&mut builder, &mut static_vec); - if new_stmts.is_empty() { - return Ok(Action::change(sseq)); - } else { - new_stmts.push(sseq); - } + // Doesn't meet threshold. + // Add attribute to seq so parent might promote it. + let inferred_latency = + cur_vec.iter().map(Self::get_inferred_latency).sum(); + s.attributes + .insert(ir::NumAttr::PromoteStatic, inferred_latency); + s.stmts = cur_vec; + return Ok(Action::Continue); } } + // Entire seq is not static, so we're only promoting the last + // bit of it if possible. + let possibly_promoted_stmts = + self.convert_vec_seq_if_threshold(&mut builder, cur_vec); + new_stmts.extend(possibly_promoted_stmts); + let new_seq = ir::Control::Seq(ir::Seq { stmts: new_stmts, attributes: ir::Attributes::default(), @@ -709,19 +985,44 @@ impl Visitor for StaticPromotion { ) -> VisResult { let mut builder = ir::Builder::new(comp, sigs); let mut new_stmts: Vec = Vec::new(); - let (mut s_stmts, d_stmts): (Vec, Vec) = + // Split the par into static and dynamic stmts + let (s_stmts, d_stmts): (Vec, Vec) = s.stmts.drain(..).partition(|s| { s.is_static() || s.get_attributes().has(ir::NumAttr::PromoteStatic) }); - if !s_stmts.is_empty() { - let s_par = self.construct_static_par(&mut builder, &mut s_stmts); - if d_stmts.is_empty() { - return Ok(Action::change(s_par)); + if d_stmts.is_empty() { + // Entire par block can be promoted to static + if Self::approx_control_vec_size(&s_stmts) > self.threshold { + // Promote entire par block to static + let static_par_stmts = + self.convert_vec_to_static(&mut builder, s_stmts); + let latency = static_par_stmts + .iter() + .map(|sc| sc.get_latency()) + .max() + .unwrap_or_else(|| unreachable!("empty par block")); + return Ok(Action::change(ir::Control::Static( + ir::StaticControl::par(static_par_stmts, latency), + ))); } else { - new_stmts.push(s_par); + // Doesn't meet threshold, but add promotion attribute since + // parent might want to promote it. + let inferred_latency = s_stmts + .iter() + .map(Self::get_inferred_latency) + .max() + .unwrap_or_else(|| unreachable!("empty par block")); + s.get_mut_attributes() + .insert(ir::NumAttr::PromoteStatic, inferred_latency); + s.stmts = s_stmts; + return Ok(Action::Continue); } } + // Otherwise just promote the par threads that we can into a static par + let s_stmts_possibly_promoted = + self.convert_vec_par_if_threshold(&mut builder, s_stmts); + new_stmts.extend(s_stmts_possibly_promoted); new_stmts.extend(d_stmts); let new_par = ir::Control::Par(ir::Par { stmts: new_stmts, @@ -733,14 +1034,46 @@ impl Visitor for StaticPromotion { fn finish_if( &mut self, s: &mut ir::If, - _comp: &mut ir::Component, - _sigs: &LibrarySignatures, + comp: &mut ir::Component, + sigs: &LibrarySignatures, _comps: &[ir::Component], ) -> VisResult { - if let Some(sif) = s.make_static() { - return Ok(Action::change(ir::Control::Static( - ir::StaticControl::If(sif), - ))); + let mut builder = ir::Builder::new(comp, sigs); + if Self::can_be_promoted(&s.tbranch) + && Self::can_be_promoted(&s.fbranch) + { + // Both branches can be promoted + let approx_size_if = Self::approx_size(&s.tbranch) + + Self::approx_size(&s.fbranch) + + APPROX_IF_SIZE; + if approx_size_if > self.threshold { + // Meets size threshold so promote to static + let static_tbranch = + self.convert_to_static(&mut s.tbranch, &mut builder); + let static_fbranch = + self.convert_to_static(&mut s.fbranch, &mut builder); + let latency = std::cmp::max( + static_tbranch.get_latency(), + static_fbranch.get_latency(), + ); + return Ok(Action::change(ir::Control::Static( + ir::StaticControl::static_if( + Rc::clone(&s.port), + Box::new(static_tbranch), + Box::new(static_fbranch), + latency, + ), + ))); + } else { + // Doesn't meet size threshold, so attach attribute + // so parent might be able to promote it. + let inferred_max_latency = std::cmp::max( + Self::get_inferred_latency(&s.tbranch), + Self::get_inferred_latency(&s.fbranch), + ); + s.get_mut_attributes() + .insert(ir::NumAttr::PromoteStatic, inferred_max_latency) + } } Ok(Action::Continue) } @@ -749,26 +1082,41 @@ impl Visitor for StaticPromotion { fn finish_while( &mut self, s: &mut ir::While, - _comp: &mut ir::Component, - _sigs: &LibrarySignatures, + comp: &mut ir::Component, + sigs: &LibrarySignatures, _comps: &[ir::Component], ) -> VisResult { - if s.body.is_static() { - // checks body is static and we have an @bound annotation - if let Some(num_repeats) = s.attributes.get(ir::NumAttr::Bound) { - let ir::Control::Static(sc) = s.body.take_control() else { - unreachable!("already checked that body is static"); - }; - let static_repeat = - ir::StaticControl::Repeat(ir::StaticRepeat { - latency: num_repeats * sc.get_latency(), - attributes: s.attributes.clone(), - body: Box::new(sc), + if self.stop_loop { + return Ok(Action::Continue); + } + let mut builder = ir::Builder::new(comp, sigs); + // First check that while loop is bounded + if let Some(num_repeats) = s.get_attributes().get(ir::NumAttr::Bound) { + // Then check that body is static/promotable + if Self::can_be_promoted(&s.body) { + let approx_size = + Self::approx_size(&s.body) + APPROX_WHILE_REPEAT_SIZE; + // Then check that it reaches the threshold + if approx_size > self.threshold { + // Turn repeat into static repeat + let sc = self.convert_to_static(&mut s.body, &mut builder); + let latency = sc.get_latency() * num_repeats; + let static_repeat = ir::StaticControl::repeat( num_repeats, - }); - return Ok(Action::Change(Box::new(ir::Control::Static( - static_repeat, - )))); + latency, + Box::new(sc), + ); + return Ok(Action::Change(Box::new(ir::Control::Static( + static_repeat, + )))); + } else { + // Attach static_promote attribute since parent control may + // want to promote + s.attributes.insert( + ir::NumAttr::PromoteStatic, + num_repeats * Self::get_inferred_latency(&s.body), + ) + } } } Ok(Action::Continue) @@ -778,23 +1126,39 @@ impl Visitor for StaticPromotion { fn finish_repeat( &mut self, s: &mut ir::Repeat, - _comp: &mut ir::Component, - _sigs: &LibrarySignatures, + comp: &mut ir::Component, + sigs: &LibrarySignatures, _comps: &[ir::Component], ) -> VisResult { - if s.body.is_static() { - let ir::Control::Static(sc) = s.body.take_control() else { - unreachable!("already checked that body is static"); - }; - let static_repeat = ir::StaticControl::Repeat(ir::StaticRepeat { - latency: s.num_repeats * sc.get_latency(), - attributes: s.attributes.clone(), - body: Box::new(sc), - num_repeats: s.num_repeats, - }); - return Ok(Action::Change(Box::new(ir::Control::Static( - static_repeat, - )))); + if self.stop_loop { + return Ok(Action::Continue); + } + let mut builder = ir::Builder::new(comp, sigs); + if Self::can_be_promoted(&s.body) { + // Body can be promoted + let approx_size = + Self::approx_size(&s.body) + APPROX_WHILE_REPEAT_SIZE; + if approx_size > self.threshold { + // Meets size threshold, so turn repeat into static repeat + let sc = self.convert_to_static(&mut s.body, &mut builder); + let latency = s.num_repeats * sc.get_latency(); + let static_repeat = ir::StaticControl::repeat( + s.num_repeats, + latency, + Box::new(sc), + ); + return Ok(Action::Change(Box::new(ir::Control::Static( + static_repeat, + )))); + } else { + // Doesn't meet threshold. + // Attach static_promote attribute since parent control may + // want to promote + s.attributes.insert( + ir::NumAttr::PromoteStatic, + s.num_repeats * Self::get_inferred_latency(&s.body), + ) + } } Ok(Action::Continue) } diff --git a/tests/passes/compile-repeat/nested.expect b/tests/passes/compile-repeat/nested.expect index a89f00381..fbc82c45a 100644 --- a/tests/passes/compile-repeat/nested.expect +++ b/tests/passes/compile-repeat/nested.expect @@ -17,14 +17,14 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { r1.write_en = 1'd1; write_r1[done] = r1.done; } - group init_repeat { + group init_repeat<"promote_static"=1> { idx.write_en = 1'd1; idx.in = 2'd0; cond_reg.write_en = 1'd1; cond_reg.in = 1'd1; init_repeat[done] = cond_reg.done & idx.done ? 1'd1; } - group incr_repeat { + group incr_repeat<"promote_static"=1> { adder.left = idx.out; adder.right = 2'd1; lt.left = adder.out; @@ -35,14 +35,14 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { idx.in = adder.out; incr_repeat[done] = cond_reg.done & idx.done ? 1'd1; } - group init_repeat0 { + group init_repeat0<"promote_static"=1> { idx0.write_en = 1'd1; idx0.in = 3'd0; cond_reg0.write_en = 1'd1; cond_reg0.in = 1'd1; init_repeat0[done] = cond_reg0.done & idx0.done ? 1'd1; } - group incr_repeat0 { + group incr_repeat0<"promote_static"=1> { adder0.left = idx0.out; adder0.right = 3'd1; lt0.left = adder0.out; diff --git a/tests/passes/compile-repeat/simple.expect b/tests/passes/compile-repeat/simple.expect index bb0cc5644..feba5eaef 100644 --- a/tests/passes/compile-repeat/simple.expect +++ b/tests/passes/compile-repeat/simple.expect @@ -19,14 +19,14 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { r2.write_en = 1'd1; write_r2[done] = r2.done; } - group init_repeat { + group init_repeat<"promote_static"=1> { idx.write_en = 1'd1; idx.in = 3'd0; cond_reg.write_en = 1'd1; cond_reg.in = 1'd1; init_repeat[done] = cond_reg.done & idx.done ? 1'd1; } - group incr_repeat { + group incr_repeat<"promote_static"=1> { adder.left = idx.out; adder.right = 3'd1; lt.left = adder.out; diff --git a/tests/passes/static-promotion/no_promote_loop.expect b/tests/passes/static-promotion/no_promote_loop.expect new file mode 100644 index 000000000..26a22869a --- /dev/null +++ b/tests/passes/static-promotion/no_promote_loop.expect @@ -0,0 +1,39 @@ +import "primitives/core.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + wires { + static<1> group A0 { + a.in = 2'd0; + a.write_en = 1'd1; + } + static<1> group B0 { + b.in = 2'd1; + b.write_en = 1'd1; + } + static<1> group C0 { + c.in = 2'd2; + c.write_en = 1'd1; + } + } + control { + seq { + repeat 10 { + @compactable static<4> seq { + A0; + B0; + C0; + C0; + } + } + @compactable static<3> seq { + A0; + B0; + C0; + } + } + } +} diff --git a/tests/passes/static-promotion/no_promote_loop.futil b/tests/passes/static-promotion/no_promote_loop.futil new file mode 100644 index 000000000..36d9575ca --- /dev/null +++ b/tests/passes/static-promotion/no_promote_loop.futil @@ -0,0 +1,42 @@ +// -p well-formed -p static-promotion -p dead-group-removal -x static-promotion:stop_loop + +import "primitives/core.futil"; + +component main() -> () { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + + wires { + group A { + a.in = 2'd0; + a.write_en = 1'b1; + A[done] = a.done; + } + + group B { + b.in = 2'd1; + b.write_en = 1'b1; + B[done] = b.done; + } + + group C { + c.in = 2'd2; + c.write_en = 1'b1; + C[done] = c.done; + } + } + + control { + seq { + repeat 10 { + seq { A; B; C; C;} + } + A; + B; + C; + } + } +} diff --git a/tests/passes/static-promotion/promote-nested.expect b/tests/passes/static-promotion/promote-nested.expect new file mode 100644 index 000000000..cb4141c2a --- /dev/null +++ b/tests/passes/static-promotion/promote-nested.expect @@ -0,0 +1,54 @@ +import "primitives/core.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + cond_reg = std_reg(1); + r0 = std_reg(2); + } + wires { + group no_upgrade { + r0.write_en = 1'd1; + no_upgrade[done] = r0.done ? 1'd1; + } + static<1> group A0 { + a.in = 2'd0; + a.write_en = 1'd1; + } + static<1> group B0 { + b.in = 2'd1; + b.write_en = 1'd1; + } + static<1> group C0 { + c.in = 2'd2; + c.write_en = 1'd1; + } + } + control { + seq { + @compactable static<4> seq { + static<1> par { + A0; + B0; + } + @compactable static<2> seq { + C0; + C0; + } + static<1> par { + A0; + B0; + } + } + no_upgrade; + static repeat 2 { + @compactable static<3> seq { + A0; + B0; + C0; + } + } + } + } +} diff --git a/tests/passes/static-promotion/promote-nested.futil b/tests/passes/static-promotion/promote-nested.futil new file mode 100644 index 000000000..e3a49fbd5 --- /dev/null +++ b/tests/passes/static-promotion/promote-nested.futil @@ -0,0 +1,59 @@ +// -p well-formed -p static-promotion -p dead-group-removal -x static-promotion:threshold=5 + +import "primitives/core.futil"; + +component main() -> () { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + cond_reg = std_reg(1); + r0 = std_reg(2); + } + + wires { + group A { + a.in = 2'd0; + a.write_en = 1'b1; + A[done] = a.done; + } + + group B { + b.in = 2'd1; + b.write_en = 1'b1; + B[done] = b.done; + } + + group C { + c.in = 2'd2; + c.write_en = 1'b1; + C[done] = c.done; + } + + group no_upgrade { + r0.write_en = 1'd1; + no_upgrade[done] = r0.done ? 1'd1; + } + } + + control { + seq { + seq { + par {A; B;} + seq {C; C;} + par {A; B;} + } + no_upgrade; + @bound(2) while cond_reg.out { + seq { + A; + B; + C; + } + } + } + + + + } +} diff --git a/tests/passes/static-promotion/threshold.expect b/tests/passes/static-promotion/threshold.expect new file mode 100644 index 000000000..976c597a7 --- /dev/null +++ b/tests/passes/static-promotion/threshold.expect @@ -0,0 +1,33 @@ +import "primitives/core.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + wires { + group A<"promote_static"=1> { + a.in = 2'd0; + a.write_en = 1'd1; + A[done] = a.done; + } + group B<"promote_static"=1> { + b.in = 2'd1; + b.write_en = 1'd1; + B[done] = b.done; + } + group C<"promote_static"=1> { + c.in = 2'd2; + c.write_en = 1'd1; + C[done] = c.done; + } + } + control { + @promote_static(4) seq { + @promote_static A; + @promote_static B; + @promote_static C; + @promote_static C; + } + } +} diff --git a/tests/passes/static-promotion/threshold.futil b/tests/passes/static-promotion/threshold.futil new file mode 100644 index 000000000..10375bdec --- /dev/null +++ b/tests/passes/static-promotion/threshold.futil @@ -0,0 +1,36 @@ +// -p well-formed -p static-promotion -p dead-group-removal -x static-promotion:threshold=4 + +import "primitives/core.futil"; + +component main() -> () { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + + wires { + group A { + a.in = 2'd0; + a.write_en = 1'b1; + A[done] = a.done; + } + + group B { + b.in = 2'd1; + b.write_en = 1'b1; + B[done] = b.done; + } + + group C { + c.in = 2'd2; + c.write_en = 1'b1; + C[done] = c.done; + } + } + + control { + // No promotion because size must be *greater* than the threshold + seq { A; B; C; C;} + } +} diff --git a/tests/passes/static-promotion/upgrade-bound.expect b/tests/passes/static-promotion/upgrade-bound.expect index fa3f416a1..9fc156ea6 100644 --- a/tests/passes/static-promotion/upgrade-bound.expect +++ b/tests/passes/static-promotion/upgrade-bound.expect @@ -21,7 +21,7 @@ static<15> component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done don } } control { - @bound(5) static repeat 5 { + static repeat 5 { @compactable static<3> seq { A0; B0;