Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schedule Compaction Minimizes Par Threads #1774

Merged
merged 10 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 82 additions & 30 deletions calyx-opt/src/passes/schedule_compaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,106 @@ impl Visitor for ScheduleCompaction {

if let Ok(order) = algo::toposort(&total_order, None) {
let mut total_time: u64 = 0;
let mut stmts: Vec<ir::StaticControl> = Vec::new();

// First we build the schedule.

for i in order {
let mut start: u64 = 0;
for node in dependency.get(&i).unwrap() {
let allow_start = schedule[node] + latency_map[node];
if allow_start > start {
start = allow_start;
}
}
// Start time is when the latest dependency finishes
let start = dependency
.get(&i)
.unwrap()
.iter()
.map(|node| schedule[node] + latency_map[node])
.max()
.unwrap_or(0);
schedule.insert(i, start);
total_time = std::cmp::max(start + latency_map[&i], total_time);
}

// We sort the schedule by start time.
let mut sorted_schedule: Vec<(NodeIndex, u64)> =
schedule.into_iter().collect();
sorted_schedule
.sort_by(|(k1, v1), (k2, v2)| (v1, k1).cmp(&(v2, k2)));
// Threads for the static par, where each entry is (thread, thread_latency)
let mut par_threads: Vec<(Vec<ir::StaticControl>, u64)> =
Vec::new();

// We encode the schedule attempting to minimize the number of
// par threads.
'outer: for (i, start) in sorted_schedule {
let control = total_order[i].take().unwrap();
let mut st_seq_stmts: Vec<ir::StaticControl> = Vec::new();
for (thread, thread_latency) in par_threads.iter_mut() {
if *thread_latency <= start {
if *thread_latency < start {
// Might need a no-op group so the schedule starts correctly
let no_op = builder.add_static_group(
"no-op",
start - *thread_latency,
);
thread.push(ir::StaticControl::Enable(
ir::StaticEnable {
group: no_op,
attributes: ir::Attributes::default(),
},
));
*thread_latency = start;
}
thread.push(control);
*thread_latency += latency_map[&i];
continue 'outer;
}
}
// We must create a new par thread.
if start > 0 {
// If start > 0, then we must add a delay to the start of the
// group.
let no_op = builder.add_static_group("no-op", start);

st_seq_stmts.push(ir::StaticControl::Enable(
ir::StaticEnable {
let no_op_enable =
ir::StaticControl::Enable(ir::StaticEnable {
group: no_op,
attributes: ir::Attributes::default(),
},
});
par_threads.push((
vec![no_op_enable, control],
start + latency_map[&i],
));
} else {
par_threads.push((vec![control], latency_map[&i]));
}
if start + latency_map[&i] > total_time {
total_time = start + latency_map[&i];
}
}

st_seq_stmts.push(control);
stmts.push(ir::StaticControl::Seq(ir::StaticSeq {
stmts: st_seq_stmts,
// Turn Vec<ir::StaticControl> -> StaticSeq
let mut par_control_threads: Vec<ir::StaticControl> = Vec::new();
for (thread, thread_latency) in par_threads {
par_control_threads.push(ir::StaticControl::Seq(
ir::StaticSeq {
stmts: thread,
attributes: ir::Attributes::default(),
latency: thread_latency,
},
));
}
// Double checking that we have built the static par correctly.
let max = par_control_threads.iter().map(|c| c.get_latency()).max();
assert!(max.unwrap() == total_time, "The schedule expects latency {}. The static par that was built has latency {}", total_time, max.unwrap());

if par_control_threads.len() == 1 {
let c = Vec::pop(&mut par_control_threads).unwrap();
Ok(Action::static_change(c))
} else {
let s_par = ir::StaticControl::Par(ir::StaticPar {
stmts: par_control_threads,
attributes: ir::Attributes::default(),
latency: start + latency_map[&i],
}));
latency: total_time,
});
Ok(Action::static_change(s_par))
}

let s_par = ir::StaticControl::Par(ir::StaticPar {
stmts,
attributes: ir::Attributes::default(),
latency: total_time,
});
return Ok(Action::static_change(s_par));
} else {
println!(
panic!(
"Error when producing topo sort. Dependency graph has a cycle."
);
}
Ok(Action::Continue)
}

fn finish_static_repeat(
Expand Down
78 changes: 39 additions & 39 deletions examples/futil/dot-product.expect
Original file line number Diff line number Diff line change
Expand Up @@ -28,92 +28,92 @@ component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) {
@generated invoke0_done = std_wire(1);
@generated early_reset_cond00_go = std_wire(1);
@generated early_reset_cond00_done = std_wire(1);
@generated early_reset_static_par_go = std_wire(1);
@generated early_reset_static_par_done = std_wire(1);
@generated early_reset_static_seq_go = std_wire(1);
@generated early_reset_static_seq_done = std_wire(1);
@generated wrapper_early_reset_cond00_go = std_wire(1);
@generated wrapper_early_reset_cond00_done = std_wire(1);
@generated while_wrapper_early_reset_static_par_go = std_wire(1);
@generated while_wrapper_early_reset_static_par_done = std_wire(1);
@generated while_wrapper_early_reset_static_seq_go = std_wire(1);
@generated while_wrapper_early_reset_static_seq_done = std_wire(1);
@generated tdcc_go = std_wire(1);
@generated tdcc_done = std_wire(1);
}
wires {
i0.write_en = invoke0_go.out | fsm.out == 4'd1 & early_reset_static_par_go.out ? 1'd1;
i0.write_en = invoke0_go.out | fsm.out == 4'd1 & early_reset_static_seq_go.out ? 1'd1;
i0.clk = clk;
i0.reset = reset;
i0.in = fsm.out == 4'd1 & early_reset_static_par_go.out ? add1.out;
i0.in = fsm.out == 4'd1 & early_reset_static_seq_go.out ? add1.out;
i0.in = invoke0_go.out ? const0.out;
early_reset_cond00_go.in = wrapper_early_reset_cond00_go.out ? 1'd1;
add1.left = fsm.out == 4'd1 & early_reset_static_par_go.out ? i0.out;
add1.right = fsm.out == 4'd1 & early_reset_static_par_go.out ? const3.out;
add1.left = fsm.out == 4'd1 & early_reset_static_seq_go.out ? i0.out;
add1.right = fsm.out == 4'd1 & early_reset_static_seq_go.out ? const3.out;
done = tdcc_done.out ? 1'd1;
fsm.write_en = early_reset_cond00_go.out | early_reset_static_par_go.out ? 1'd1;
fsm.write_en = early_reset_cond00_go.out | early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 4'd0 & early_reset_cond00_go.out ? adder.out;
fsm.in = fsm.out == 4'd0 & early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? 4'd0;
fsm.in = fsm.out != 4'd7 & early_reset_static_par_go.out ? adder0.out;
fsm.in = fsm.out == 4'd0 & early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? 4'd0;
fsm.in = fsm.out != 4'd7 & early_reset_static_seq_go.out ? adder0.out;
adder.left = early_reset_cond00_go.out ? fsm.out;
adder.right = early_reset_cond00_go.out ? 4'd1;
add0.left = fsm.out == 4'd6 & early_reset_static_par_go.out ? v0.read_data;
add0.right = fsm.out == 4'd6 & early_reset_static_par_go.out ? B_read0_0.out;
v0.write_en = fsm.out == 4'd6 & early_reset_static_par_go.out ? 1'd1;
add0.left = fsm.out == 4'd6 & early_reset_static_seq_go.out ? v0.read_data;
add0.right = fsm.out == 4'd6 & early_reset_static_seq_go.out ? B_read0_0.out;
v0.write_en = fsm.out == 4'd6 & early_reset_static_seq_go.out ? 1'd1;
v0.clk = clk;
v0.addr0 = fsm.out == 4'd6 & early_reset_static_par_go.out ? const2.out;
v0.addr0 = fsm.out == 4'd6 & early_reset_static_seq_go.out ? const2.out;
v0.reset = reset;
v0.write_data = fsm.out == 4'd6 & early_reset_static_par_go.out ? add0.out;
comb_reg.write_en = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? 1'd1;
v0.write_data = fsm.out == 4'd6 & early_reset_static_seq_go.out ? add0.out;
comb_reg.write_en = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? 1'd1;
comb_reg.clk = clk;
comb_reg.reset = reset;
comb_reg.in = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? le0.out;
comb_reg.in = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? le0.out;
early_reset_cond00_done.in = ud.out;
while_wrapper_early_reset_static_par_go.in = !while_wrapper_early_reset_static_par_done.out & fsm0.out == 2'd2 & tdcc_go.out ? 1'd1;
while_wrapper_early_reset_static_seq_go.in = !while_wrapper_early_reset_static_seq_done.out & fsm0.out == 2'd2 & tdcc_go.out ? 1'd1;
invoke0_go.in = !invoke0_done.out & fsm0.out == 2'd0 & tdcc_go.out ? 1'd1;
while_wrapper_early_reset_static_par_done.in = !comb_reg.out & fsm.out == 4'd0 ? 1'd1;
tdcc_go.in = go;
A0.clk = clk;
A0.addr0 = fsm.out == 4'd0 & early_reset_static_par_go.out ? i0.out;
A0.addr0 = fsm.out == 4'd0 & early_reset_static_seq_go.out ? i0.out;
A0.reset = reset;
fsm0.write_en = fsm0.out == 2'd3 | fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out | fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out | fsm0.out == 2'd2 & while_wrapper_early_reset_static_par_done.out & tdcc_go.out ? 1'd1;
fsm0.write_en = fsm0.out == 2'd3 | fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out | fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out | fsm0.out == 2'd2 & while_wrapper_early_reset_static_seq_done.out & tdcc_go.out ? 1'd1;
fsm0.clk = clk;
fsm0.reset = reset;
fsm0.in = fsm0.out == 2'd0 & invoke0_done.out & tdcc_go.out ? 2'd1;
fsm0.in = fsm0.out == 2'd3 ? 2'd0;
fsm0.in = fsm0.out == 2'd2 & while_wrapper_early_reset_static_par_done.out & tdcc_go.out ? 2'd3;
fsm0.in = fsm0.out == 2'd2 & while_wrapper_early_reset_static_seq_done.out & tdcc_go.out ? 2'd3;
fsm0.in = fsm0.out == 2'd1 & wrapper_early_reset_cond00_done.out & tdcc_go.out ? 2'd2;
mult_pipe0.clk = clk;
mult_pipe0.left = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? A_read0_0.out;
mult_pipe0.go = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? 1'd1;
mult_pipe0.left = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? A_read0_0.out;
mult_pipe0.go = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? 1'd1;
mult_pipe0.reset = reset;
mult_pipe0.right = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_par_go.out ? B_read0_0.out;
adder0.left = early_reset_static_par_go.out ? fsm.out;
adder0.right = early_reset_static_par_go.out ? 4'd1;
mult_pipe0.right = fsm.out >= 4'd1 & fsm.out < 4'd4 & early_reset_static_seq_go.out ? B_read0_0.out;
adder0.left = early_reset_static_seq_go.out ? fsm.out;
adder0.right = early_reset_static_seq_go.out ? 4'd1;
invoke0_done.in = i0.done;
early_reset_static_par_done.in = ud0.out;
le0.left = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? i0.out;
le0.right = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_par_go.out ? const1.out;
early_reset_static_seq_go.in = while_wrapper_early_reset_static_seq_go.out ? 1'd1;
le0.left = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? i0.out;
le0.right = early_reset_cond00_go.out | fsm.out == 4'd7 & early_reset_static_seq_go.out ? const1.out;
signal_reg.write_en = fsm.out == 4'd0 & signal_reg.out | fsm.out == 4'd0 & !signal_reg.out & wrapper_early_reset_cond00_go.out ? 1'd1;
signal_reg.clk = clk;
signal_reg.reset = reset;
signal_reg.in = fsm.out == 4'd0 & !signal_reg.out & wrapper_early_reset_cond00_go.out ? 1'd1;
signal_reg.in = fsm.out == 4'd0 & signal_reg.out ? 1'd0;
B0.clk = clk;
B0.addr0 = fsm.out == 4'd0 & early_reset_static_par_go.out ? i0.out;
B0.addr0 = fsm.out == 4'd0 & early_reset_static_seq_go.out ? i0.out;
B0.reset = reset;
B_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd5 & fsm.out < 4'd6) & early_reset_static_par_go.out ? 1'd1;
B_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd5 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
B_read0_0.clk = clk;
B_read0_0.reset = reset;
B_read0_0.in = fsm.out == 4'd0 & early_reset_static_par_go.out ? B0.read_data;
B_read0_0.in = fsm.out == 4'd5 & early_reset_static_par_go.out ? A_read0_0.out;
B_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? B0.read_data;
B_read0_0.in = fsm.out == 4'd5 & early_reset_static_seq_go.out ? A_read0_0.out;
wrapper_early_reset_cond00_go.in = !wrapper_early_reset_cond00_done.out & fsm0.out == 2'd1 & tdcc_go.out ? 1'd1;
wrapper_early_reset_cond00_done.in = fsm.out == 4'd0 & signal_reg.out ? 1'd1;
early_reset_static_seq_done.in = ud0.out;
tdcc_done.in = fsm0.out == 2'd3 ? 1'd1;
early_reset_static_par_go.in = while_wrapper_early_reset_static_par_go.out ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 | fsm.out == 4'd4 & fsm.out >= 4'd1 & fsm.out < 4'd5 & fsm.out < 4'd5) & early_reset_static_par_go.out ? 1'd1;
while_wrapper_early_reset_static_seq_done.in = !comb_reg.out & fsm.out == 4'd0 ? 1'd1;
A_read0_0.write_en = (fsm.out == 4'd0 & fsm.out < 4'd7 | fsm.out == 4'd4 & fsm.out < 4'd7) & early_reset_static_seq_go.out ? 1'd1;
A_read0_0.clk = clk;
A_read0_0.reset = reset;
A_read0_0.in = fsm.out == 4'd0 & early_reset_static_par_go.out ? A0.read_data;
A_read0_0.in = fsm.out == 4'd4 & early_reset_static_par_go.out ? mult_pipe0.out;
A_read0_0.in = fsm.out == 4'd0 & early_reset_static_seq_go.out ? A0.read_data;
A_read0_0.in = fsm.out == 4'd4 & early_reset_static_seq_go.out ? mult_pipe0.out;
}
control {}
}
32 changes: 16 additions & 16 deletions examples/futil/simple.expect
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ static<5> component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done
@generated ud = undef(1);
@generated adder = std_add(3);
@generated signal_reg = std_reg(1);
@generated early_reset_static_par_go = std_wire(1);
@generated early_reset_static_par_done = std_wire(1);
@generated wrapper_early_reset_static_par_go = std_wire(1);
@generated wrapper_early_reset_static_par_done = std_wire(1);
@generated early_reset_static_seq_go = std_wire(1);
@generated early_reset_static_seq_done = std_wire(1);
@generated wrapper_early_reset_static_seq_go = std_wire(1);
@generated wrapper_early_reset_static_seq_done = std_wire(1);
}
wires {
done = wrapper_early_reset_static_par_done.out ? 1'd1;
fsm.write_en = early_reset_static_par_go.out ? 1'd1;
done = wrapper_early_reset_static_seq_done.out ? 1'd1;
fsm.write_en = early_reset_static_seq_go.out ? 1'd1;
fsm.clk = clk;
fsm.reset = reset;
fsm.in = fsm.out != 3'd4 & early_reset_static_par_go.out ? adder.out;
fsm.in = fsm.out == 3'd4 & early_reset_static_par_go.out ? 3'd0;
adder.left = early_reset_static_par_go.out ? fsm.out;
adder.right = early_reset_static_par_go.out ? 3'd1;
wrapper_early_reset_static_par_go.in = go;
wrapper_early_reset_static_par_done.in = fsm.out == 3'd0 & signal_reg.out ? 1'd1;
early_reset_static_par_done.in = ud.out;
signal_reg.write_en = fsm.out == 3'd0 & signal_reg.out | fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_par_go.out ? 1'd1;
fsm.in = fsm.out != 3'd4 & early_reset_static_seq_go.out ? adder.out;
fsm.in = fsm.out == 3'd4 & early_reset_static_seq_go.out ? 3'd0;
adder.left = early_reset_static_seq_go.out ? fsm.out;
adder.right = early_reset_static_seq_go.out ? 3'd1;
wrapper_early_reset_static_seq_done.in = fsm.out == 3'd0 & signal_reg.out ? 1'd1;
early_reset_static_seq_go.in = wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.write_en = fsm.out == 3'd0 & signal_reg.out | fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.clk = clk;
signal_reg.reset = reset;
signal_reg.in = fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_par_go.out ? 1'd1;
signal_reg.in = fsm.out == 3'd0 & !signal_reg.out & wrapper_early_reset_static_seq_go.out ? 1'd1;
signal_reg.in = fsm.out == 3'd0 & signal_reg.out ? 1'd0;
early_reset_static_par_go.in = wrapper_early_reset_static_par_go.out ? 1'd1;
early_reset_static_seq_done.in = ud.out;
wrapper_early_reset_static_seq_go.in = go;
}
control {}
}
Loading