Skip to content

Commit

Permalink
Calyx-py switch-case optimization (#2413)
Browse files Browse the repository at this point in the history
Currently, Calyx-py supports "switch-case"-like semantics by having a
single par block with if blocks as arms.
i.e. if we wanted to have a switch-case statement like:
```
switch X {
  case 1 { do_case_one() }
  case 2 { do_case_two() }
}
```
then the generated Calyx would look sth like:
```
par {
  if eq_x_1.out { do_case_one; }
  if eq_x_2.out { do_case_two; }
}
```
But a more cycle-saving and "simple" way to implement case-like
semantics would be a series of nested if-else blocks. i.e:
```
if eq_x_1.out { do_case_one; }
else {
  if eq_x_2.out { do_case_two; }
}
```
This PR contains the above change! In our simple example test case
([`calyx-py/test/case-multi-component.py`](main...calyx-py-case-opt#diff-78bd598e9d32aa02e7751e77d85587ece1a4996eaa7fc04be824ac7b04fbe2ff))
the change helps save 3 out of the original 9 cycles.
  • Loading branch information
ayakayorihiro authored Feb 12, 2025
1 parent 10e9314 commit 08f426f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 15 deletions.
9 changes: 5 additions & 4 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,19 @@ def case(
Branches are implemented via mutually exclusive `if` statements in the
component's `control` block."""
width = self.infer_width(signal)
ifs = []
for branch, controllable in cases.items():
curr_case = ast.Empty()
for branch, controllable in reversed(cases.items()):
prev_case = curr_case
std_eq = self.eq(
width, self.generate_name(f"{signal.name}_eq_{branch}"), signed
)

with self.continuous:
std_eq.left = signal
std_eq.right = const(width, branch)
ifs.append(if_(std_eq["out"], controllable))
curr_case = if_(std_eq["out"], controllable, prev_case)

return par(*ifs)
return curr_case

def port_width(self, port: ExprBuilder) -> int:
"""Get the width of an expression, which may be a port of this component."""
Expand Down
85 changes: 85 additions & 0 deletions calyx-py/test/case-multi-component.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import "primitives/core.futil";
import "primitives/memories/comb.futil";
import "primitives/binary_operators.futil";
component main() -> () {
cells {
@external mem = comb_mem_d1(32, 1, 1);
r = std_reg(32);
ans = std_reg(32);
id_1 = identity();
id_2 = identity();
id_3 = identity();
id_4 = identity();
id_5 = identity();
r_out_eq_4_1 = std_eq(32);
r_out_eq_3_2 = std_eq(32);
r_out_eq_2_3 = std_eq(32);
r_out_eq_1_4 = std_eq(32);
r_out_eq_0_5 = std_eq(32);
}
wires {
group read {
mem.addr0 = 1'd0;
r.in = mem.read_data;
r.write_en = 1'd1;
read[done] = r.done;
}
group write {
mem.addr0 = 1'd0;
mem.write_en = 1'd1;
mem.write_data = r.out;
write[done] = mem.done;
}
r_out_eq_4_1.left = r.out;
r_out_eq_4_1.right = 32'd4;
r_out_eq_3_2.left = r.out;
r_out_eq_3_2.right = 32'd3;
r_out_eq_2_3.left = r.out;
r_out_eq_2_3.right = 32'd2;
r_out_eq_1_4.left = r.out;
r_out_eq_1_4.right = 32'd1;
r_out_eq_0_5.left = r.out;
r_out_eq_0_5.right = 32'd0;
}
control {
seq {
read;
if r_out_eq_0_5.out {
invoke id_1(in_1=r.out)(out=ans.in);
} else {
if r_out_eq_1_4.out {
invoke id_2(in_1=r.out)(out=ans.in);
} else {
if r_out_eq_2_3.out {
invoke id_3(in_1=r.out)(out=ans.in);
} else {
if r_out_eq_3_2.out {
invoke id_4(in_1=r.out)(out=ans.in);
} else {
if r_out_eq_4_1.out {
invoke id_5(in_1=r.out)(out=ans.in);
}
}
}
}
}
write;
}
}
}
component identity(in_1: 32) -> (out: 32) {
cells {
r = std_reg(32);
}
wires {
group save {
r.in = in_1;
r.write_en = 1'd1;
save[done] = r.done;
}
out = r.out;
}
control {
save;
}
}
66 changes: 66 additions & 0 deletions calyx-py/test/case-multi-component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import calyx.builder as cb


def insert_identity_component(prog):
identity = prog.component("identity")
r = identity.reg(32, "r")
in_1 = identity.input("in_1", 32)
identity.output("out", 32)

with identity.group("save") as save:
r.in_ = in_1
r.write_en = cb.HI
save.done = r.done

with identity.continuous:
identity.this().out = r.out

identity.control += save

return identity


def make_program(prog):
main = prog.component("main")
mem = main.comb_mem_d1("mem", 32, 1, 1, is_external=True)
reg = main.reg(32, "r")
ans = main.reg(32, "ans")
id_component = insert_identity_component(prog)
# make 5 versions of ident
num_ident = 5
ids = []
for i in range(1, 1 + num_ident):
ids.append(main.cell(f"id_{i}", id_component))

# group to read from the memory
with main.group("read") as read:
mem.addr0 = cb.LO
reg.in_ = mem.read_data
reg.write_en = cb.HI
read.done = reg.done

with main.group("write") as write:
mem.addr0 = cb.LO
mem.write_en = cb.HI
mem.write_data = reg.out
write.done = mem.done

main.control += read
main.control += main.case(
reg.out,
{
n: cb.invoke(ids[n], in_in_1=reg.out, out_out=ans.in_)
for n in range(num_ident)
},
)
main.control += write


def build():
prog = cb.Builder()
make_program(prog)
return prog.program


if __name__ == "__main__":
build().emit()
21 changes: 10 additions & 11 deletions calyx-py/test/case.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@ import "primitives/binary_operators.futil";
component my_comp(in_1: 8) -> (out_1: 16) {
cells {
comp_reg = std_reg(1);
in_1_eq_1_1 = std_eq(8);
in_1_eq_2_2 = std_eq(8);
in_1_eq_2_1 = std_eq(8);
in_1_eq_1_2 = std_eq(8);
}
wires {
group my_group {

}
in_1_eq_1_1.left = in_1;
in_1_eq_1_1.right = 8'd1;
in_1_eq_2_2.left = in_1;
in_1_eq_2_2.right = 8'd2;
in_1_eq_2_1.left = in_1;
in_1_eq_2_1.right = 8'd2;
in_1_eq_1_2.left = in_1;
in_1_eq_1_2.right = 8'd1;
}
control {
par {
if in_1_eq_1_1.out {
my_group;
}
if in_1_eq_2_2.out {
if in_1_eq_1_2.out {
my_group;
} else {
if in_1_eq_2_1.out {
invoke comp_reg(in=1'd1)();
}
}
Expand Down

0 comments on commit 08f426f

Please sign in to comment.