Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify logic on static designs #1775

Merged
merged 10 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions calyx-opt/src/passes/simplify_static_guards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ impl SimplifyStaticGuards {
cur_anded_intervals: &mut Vec<(u64, u64)>,
) -> Option<ir::Guard<ir::StaticTiming>> {
match g {
ir::Guard::Not(_)
| ir::Guard::Or(_, _)
| ir::Guard::True
| ir::Guard::CompOp(_, _, _)
| ir::Guard::Port(_) => Some(g),
ir::Guard::And(g1, g2) => {
// recursively call separate_anded_intervals on g1 and g2
let rest_g1 =
Expand All @@ -65,6 +60,11 @@ impl SimplifyStaticGuards {
cur_anded_intervals.push(static_timing_info.get_interval());
None
}
ir::Guard::True
| ir::Guard::CompOp(_, _, _)
| ir::Guard::Not(_)
| ir::Guard::Or(_, _)
| ir::Guard::Port(_) => Some(g),
}
}

Expand All @@ -75,7 +75,7 @@ impl SimplifyStaticGuards {
/// For example: (port.out | !port1.out) & (port2.out == port3.out) & %[2:8] & %[5:10] ?
/// becomes (port.out | !port1.out) & (port2.out == port3.out) & %[5:8] ?
/// by "combining: %[2:8] & %[5:10]
fn simplify_guard(
fn simplify_anded_guards(
guard: ir::Guard<ir::StaticTiming>,
group_latency: u64,
) -> ir::Guard<ir::StaticTiming> {
Expand Down Expand Up @@ -121,6 +121,30 @@ impl SimplifyStaticGuards {
(Some(rg), Some(ig)) => ir::Guard::And(Box::new(rg), Box::new(ig)),
}
}

fn simplify_guard(
guard: ir::Guard<ir::StaticTiming>,
group_latency: u64,
) -> ir::Guard<ir::StaticTiming> {
match guard {
ir::Guard::Not(g) => ir::Guard::Not(Box::new(
Self::simplify_guard(*g, group_latency),
)),
ir::Guard::Or(g1, g2) => ir::Guard::Or(
Box::new(Self::simplify_guard(*g1, group_latency)),
Box::new(Self::simplify_guard(*g2, group_latency)),
),
ir::Guard::And(_, _) => {
Self::simplify_anded_guards(guard, group_latency)
}
ir::Guard::Info(_) => {
Self::simplify_anded_guards(guard, group_latency)
}
ir::Guard::Port(_)
| ir::Guard::True
| ir::Guard::CompOp(_, _, _) => guard,
}
}
}

impl Visitor for SimplifyStaticGuards {
Expand Down
119 changes: 82 additions & 37 deletions calyx-opt/src/passes/static_promotion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ pub struct StaticPromotion {
/// Threshold for promotion
threshold: u64,
/// Whether we should stop promoting when we see a loop.
stop_loop: bool,
cycle_limit: Option<u64>,
}

// Override constructor to build latency_data information from the primitives
Expand Down Expand Up @@ -177,13 +177,13 @@ impl ConstructVisitor for StaticPromotion {
}
latency_data.insert(prim.name, GoDone::new(go_ports));
}
let (threshold, stop_loop) = Self::get_threshold(ctx);
let (threshold, cycle_limit) = Self::get_threshold(ctx);
Ok(StaticPromotion {
latency_data,
static_group_name: HashMap::new(),
static_component_latencies: HashMap::new(),
threshold,
stop_loop,
cycle_limit,
})
}

Expand All @@ -206,7 +206,7 @@ impl Named for StaticPromotion {
impl StaticPromotion {
// Looks through ctx to get the given command line threshold.
// Default threshold = 1
fn get_threshold(ctx: &ir::Context) -> (u64, bool)
fn get_threshold(ctx: &ir::Context) -> (u64, Option<u64>)
where
Self: Named,
{
Expand All @@ -226,13 +226,24 @@ impl StaticPromotion {
})
.collect();

let mut stop_loop = false;
given_opts.iter().for_each(|arg| {
if *arg == "stop_loop" {
stop_loop = true
// searching for "-x static-promotion:cycle-limit=200" and getting back "200"
let cycle_limit_str: Option<&str> = given_opts.iter().find_map(|arg| {
let split: Vec<&str> = arg.split('=').collect();
if let Some(str) = split.first() {
if str == &"cycle-limit" {
return Some(split[1]);
}
}
None
});

// Default to None. There may be a more idiomatic way to do this.
let cycle_limit = match cycle_limit_str.unwrap_or("None").parse::<u64>()
{
Ok(n) => Some(n),
Err(_) => None,
};

// searching for "-x static-promotion:threshold=1" and getting back "1"
let threshold: Option<&str> = given_opts.iter().find_map(|arg| {
let split: Vec<&str> = arg.split('=').collect();
Expand All @@ -246,9 +257,10 @@ impl StaticPromotion {

// Need to convert string argument into int argument
// Always default to threshold=1
// Default cycle limit = 2^25 = 33554432
(
threshold.unwrap_or("1").parse::<u64>().unwrap_or(1),
stop_loop,
cycle_limit,
)
}

Expand Down Expand Up @@ -502,6 +514,13 @@ impl StaticPromotion {
c.is_static() || c.has_attribute(ir::NumAttr::PromoteStatic)
}

fn within_cycle_limit(&self, latency: u64) -> bool {
if self.cycle_limit.is_none() {
return true;
}
latency < self.cycle_limit.unwrap()
}

/// If we've already constructed the static group then use the already existing
/// group. Otherwise construct `static group` and then return that.
fn construct_static_group(
Expand Down Expand Up @@ -743,18 +762,23 @@ impl StaticPromotion {
v.iter().map(Self::approx_size).sum()
}

/// First checks if the vec of control statements meets the self.threshold.
/// First checks if the vec of control statements satsifies the threshold
/// and cycle count threshold
/// (That is, whether the combined approx_size of the static_vec is greater)
/// Than the threshold.
/// than the threshold and cycle count is less than cycle limit).
/// If so, converts vec of control to a static seq, and returns a vec containing
/// the static seq.
/// Otherwise, just returns the vec without changing it.
fn convert_vec_seq_if_threshold(
fn convert_vec_seq_if_sat(
&mut self,
builder: &mut ir::Builder,
control_vec: Vec<ir::Control>,
) -> Vec<ir::Control> {
if Self::approx_control_vec_size(&control_vec) <= self.threshold {
if Self::approx_control_vec_size(&control_vec) <= self.threshold
|| !self.within_cycle_limit(
control_vec.iter().map(Self::get_inferred_latency).sum(),
)
{
// Return unchanged vec
return control_vec;
}
Expand All @@ -768,16 +792,25 @@ impl StaticPromotion {
vec![sseq]
}

/// First checks if the vec of control statements meets the self.threshold.
/// First checks if the vec of control statements meets the self.threshold
/// and is within self.cycle_limit
/// If so, converts vec of control to a static par, and returns a vec containing
/// the static par.
/// Otherwise, just returns the vec without changing it.
fn convert_vec_par_if_threshold(
fn convert_vec_par_if_sat(
&mut self,
builder: &mut ir::Builder,
control_vec: Vec<ir::Control>,
) -> Vec<ir::Control> {
if Self::approx_control_vec_size(&control_vec) <= self.threshold {
if Self::approx_control_vec_size(&control_vec) <= self.threshold
|| !self.within_cycle_limit(
control_vec
.iter()
.map(Self::get_inferred_latency)
.max()
.unwrap_or_else(|| unreachable!("Non Empty Par Block")),
)
{
// Return unchanged vec
return control_vec;
}
Expand Down Expand Up @@ -933,17 +966,22 @@ impl Visitor for StaticPromotion {
} else {
// Accumualte cur_vec into a static seq if it meets threshold
let possibly_promoted_stmts =
self.convert_vec_seq_if_threshold(&mut builder, cur_vec);
self.convert_vec_seq_if_sat(&mut builder, cur_vec);
new_stmts.extend(possibly_promoted_stmts);
cur_vec = Vec::new();
// Add the current (non-promotable) stmt
new_stmts.push(stmt);
// New cur_vec
cur_vec = Vec::new();
}
}
if new_stmts.is_empty() {
// The entire seq can be promoted
let approx_size: u64 = cur_vec.iter().map(Self::approx_size).sum();
if approx_size > self.threshold {
if approx_size > self.threshold
&& self.within_cycle_limit(
cur_vec.iter().map(Self::get_inferred_latency).sum(),
)
{
// Promote entire seq to a static seq
let s_seq_stmts =
self.convert_vec_to_static(&mut builder, cur_vec);
Expand All @@ -970,7 +1008,7 @@ impl Visitor for StaticPromotion {
// Entire seq is not static, so we're only promoting the last
// bit of it if possible.
let possibly_promoted_stmts =
self.convert_vec_seq_if_threshold(&mut builder, cur_vec);
self.convert_vec_seq_if_sat(&mut builder, cur_vec);
new_stmts.extend(possibly_promoted_stmts);

let new_seq = ir::Control::Seq(ir::Seq {
Expand All @@ -997,7 +1035,15 @@ impl Visitor for StaticPromotion {
});
if d_stmts.is_empty() {
// Entire par block can be promoted to static
if Self::approx_control_vec_size(&s_stmts) > self.threshold {
if Self::approx_control_vec_size(&s_stmts) > self.threshold
&& self.within_cycle_limit(
s_stmts
.iter()
.map(Self::get_inferred_latency)
.max()
.unwrap_or_else(|| unreachable!("Empty Par Block")),
)
{
// Promote entire par block to static
let static_par_stmts =
self.convert_vec_to_static(&mut builder, s_stmts);
Expand Down Expand Up @@ -1025,7 +1071,7 @@ impl Visitor for StaticPromotion {
}
// Otherwise just promote the par threads that we can into a static par
let s_stmts_possibly_promoted =
self.convert_vec_par_if_threshold(&mut builder, s_stmts);
self.convert_vec_par_if_sat(&mut builder, s_stmts);
new_stmts.extend(s_stmts_possibly_promoted);
new_stmts.extend(d_stmts);
let new_par = ir::Control::Par(ir::Par {
Expand All @@ -1050,16 +1096,18 @@ impl Visitor for StaticPromotion {
let approx_size_if = Self::approx_size(&s.tbranch)
+ Self::approx_size(&s.fbranch)
+ APPROX_IF_SIZE;
if approx_size_if > self.threshold {
let latency = std::cmp::max(
Self::get_inferred_latency(&s.tbranch),
Self::get_inferred_latency(&s.fbranch),
);
if approx_size_if > self.threshold
&& self.within_cycle_limit(latency)
{
// Meets size threshold so promote to static
let static_tbranch =
self.convert_to_static(&mut s.tbranch, &mut builder);
let static_fbranch =
self.convert_to_static(&mut s.fbranch, &mut builder);
let latency = std::cmp::max(
static_tbranch.get_latency(),
static_fbranch.get_latency(),
);
return Ok(Action::change(ir::Control::Static(
ir::StaticControl::static_if(
Rc::clone(&s.port),
Expand Down Expand Up @@ -1090,21 +1138,20 @@ impl Visitor for StaticPromotion {
sigs: &LibrarySignatures,
_comps: &[ir::Component],
) -> VisResult {
if self.stop_loop {
return Ok(Action::Continue);
}
let mut builder = ir::Builder::new(comp, sigs);
// First check that while loop is bounded
if let Some(num_repeats) = s.get_attributes().get(ir::NumAttr::Bound) {
// Then check that body is static/promotable
if Self::can_be_promoted(&s.body) {
let approx_size =
Self::approx_size(&s.body) + APPROX_WHILE_REPEAT_SIZE;
let latency = Self::get_inferred_latency(&s.body) * num_repeats;
// Then check that it reaches the threshold
if approx_size > self.threshold {
if approx_size > self.threshold
&& self.within_cycle_limit(latency)
{
// Turn repeat into static repeat
let sc = self.convert_to_static(&mut s.body, &mut builder);
let latency = sc.get_latency() * num_repeats;
let static_repeat = ir::StaticControl::repeat(
num_repeats,
latency,
Expand Down Expand Up @@ -1134,18 +1181,16 @@ impl Visitor for StaticPromotion {
sigs: &LibrarySignatures,
_comps: &[ir::Component],
) -> VisResult {
if self.stop_loop {
return Ok(Action::Continue);
}
let mut builder = ir::Builder::new(comp, sigs);
if Self::can_be_promoted(&s.body) {
// Body can be promoted
let approx_size =
Self::approx_size(&s.body) + APPROX_WHILE_REPEAT_SIZE;
if approx_size > self.threshold {
let latency = Self::get_inferred_latency(&s.body) * s.num_repeats;
if approx_size > self.threshold && self.within_cycle_limit(latency)
{
// Meets size threshold, so turn repeat into static repeat
let sc = self.convert_to_static(&mut s.body, &mut builder);
let latency = s.num_repeats * sc.get_latency();
let static_repeat = ir::StaticControl::repeat(
s.num_repeats,
latency,
Expand Down
4 changes: 2 additions & 2 deletions examples/futil/dot-product.expect
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
B0.clk = clk;
B0.addr0 = fsm.out == 4'd0 & early_reset_static_seq_go.out ? i0.out;
B0.reset = reset;
B_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd5 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
B_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd5) & early_reset_static_seq_go.out ? 1'd1;
B_read0_0.clk = clk;
B_read0_0.reset = reset;
B_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? B0.read_data;
Expand All @@ -109,7 +109,7 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
early_reset_static_seq_done.in = ud0.out;
tdcc_done.in = fsm0.out == 2'd3 ? 1'd1;
while_wrapper_early_reset_static_seq_done.in = !comb_reg.out & fsm.out == 4'd0 ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd4 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd4) & early_reset_static_seq_go.out ? 1'd1;
A_read0_0.clk = clk;
A_read0_0.reset = reset;
A_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? A0.read_data;
Expand Down
16 changes: 8 additions & 8 deletions tests/passes/simplify-static-guards/basic.futil
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
}
wires {
static<10> group my_group {
a.write_en = (%[2:3] | lt.out) & %[1:5] ? 1'd1; // don't simplify
b.write_en = %[2:3] & (lt.out | gt.out) & %[1:5] ? 1'd1; // %[1:5] is redundant
a.write_en = (%[2:3] | lt.out) & %[1:5] ? 1'd1; // don't simplify
b.write_en = %[2:3] & (lt.out | gt.out) & %[1:5] ? 1'd1; // %[1:5] is redundant
c.write_en = %[2:5] & (%[5:7] | lt.out) & %[3:7] & %[4:10] ? 1'd1; // %[5:7] shouldn't change, but can simplify rest to %[4:5]
d.write_en = %[2:5] & %[6:9] ? 1'd1; // assignment is false
e.write_en = %[0:10] & lt.out ? 1'd1; // no static timing necesary, since %[0:10] is same as group
a.in = 32'd1;
b.in = 32'd2;
c.in = 32'd3;
d.in = 32'd4;
d.write_en = %[2:5] & %[6:9] ? 1'd1; // assignment is false
e.write_en = %[0:10] & lt.out ? 1'd1; // no static timing necesary, since %[0:10] is same as group
a.in = 32'd1;
b.in = 32'd2;
c.in = 32'd3;
d.in = 32'd4;
}
}

Expand Down
Loading