Skip to content

Commit

Permalink
[hlsl-out] fix fallthrough in switch statements
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy authored and jimblandy committed May 30, 2022
1 parent 7c7e962 commit 91ee407
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 130 deletions.
46 changes: 28 additions & 18 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
let indent_level_1 = level.next();
let indent_level_2 = indent_level_1.next();

for case in cases {
for (i, case) in cases.iter().enumerate() {
match case.value {
crate::SwitchValue::Integer(value) => writeln!(
self.out,
Expand All @@ -1663,25 +1663,35 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

// FXC doesn't support fallthrough so we duplicate the body of the following case blocks
if case.fall_through {
// Generate each fallthrough case statement in a new block. This is done to
// prevent symbol collision of variables declared in these cases statements.
writeln!(self.out, "{}/* fallthrough */", indent_level_2)?;
writeln!(self.out, "{}{{", indent_level_2)?;
}
for sta in case.body.iter() {
self.write_stmt(
module,
sta,
func_ctx,
back::Level(indent_level_2.0 + usize::from(case.fall_through)),
)?;
}
let curr_len = i + 1;
let end_case_idx = curr_len
+ cases
.iter()
.skip(curr_len)
.position(|case| !case.fall_through)
.unwrap();
let indent_level_3 = indent_level_2.next();
for case in &cases[i..=end_case_idx] {
writeln!(self.out, "{}{{", indent_level_2)?;
for sta in case.body.iter() {
self.write_stmt(module, sta, func_ctx, indent_level_3)?;
}
writeln!(self.out, "{}}}", indent_level_2)?;
}

if case.fall_through {
writeln!(self.out, "{}}}", indent_level_2)?;
} else if case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", indent_level_2)?;
let last_case = &cases[end_case_idx];
if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", indent_level_2)?;
}
} else {
for sta in case.body.iter() {
self.write_stmt(module, sta, func_ctx, indent_level_2)?;
}
if case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", indent_level_2)?;
}
}

writeln!(self.out, "{}}}", indent_level_1)?;
Expand Down
7 changes: 5 additions & 2 deletions tests/in/control-flow.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
pos = 2;
fallthrough;
}
case 4: {}
default: {
case 4: {
pos = 3;
fallthrough;
}
default: {
pos = 4;
}
}

Expand Down
9 changes: 5 additions & 4 deletions tests/out/glsl/control-flow.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ void main() {
pos = 2;
/* fallthrough */
case 4:
break;
default:
pos = 3;
/* fallthrough */
default:
pos = 4;
break;
}
switch(0u) {
Expand All @@ -70,8 +71,8 @@ void main() {
default:
break;
}
int _e10 = pos;
switch(_e10) {
int _e11 = pos;
switch(_e11) {
case 1:
pos = 0;
break;
Expand Down
24 changes: 19 additions & 5 deletions tests/out/hlsl/control-flow.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,28 @@ void main(uint3 global_id : SV_DispatchThreadID)
break;
}
case 3: {
/* fallthrough */
{
pos = 2;
}
{
pos = 3;
}
{
pos = 4;
}
break;
}
case 4: {
{
pos = 3;
}
{
pos = 4;
}
break;
}
default: {
pos = 3;
pos = 4;
break;
}
}
Expand All @@ -81,8 +93,8 @@ void main(uint3 global_id : SV_DispatchThreadID)
break;
}
}
int _expr10 = pos;
switch(_expr10) {
int _expr11 = pos;
switch(_expr11) {
case 1: {
pos = 0;
break;
Expand All @@ -92,10 +104,12 @@ void main(uint3 global_id : SV_DispatchThreadID)
return;
}
case 3: {
/* fallthrough */
{
pos = 2;
}
{
return;
}
}
case 4: {
return;
Expand Down
8 changes: 4 additions & 4 deletions tests/out/msl/control-flow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ kernel void main_(
pos = 2;
}
case 4: {
break;
pos = 3;
}
default: {
pos = 3;
pos = 4;
break;
}
}
Expand All @@ -87,8 +87,8 @@ kernel void main_(
break;
}
}
int _e10 = pos;
switch(_e10) {
int _e11 = pos;
switch(_e11) {
case 1: {
pos = 0;
break;
Expand Down
Loading

0 comments on commit 91ee407

Please sign in to comment.