Skip to content

Commit

Permalink
[spv-in] Convert conditional backedges to break if.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed Apr 27, 2023
1 parent 9befaed commit a006953
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 16 deletions.
8 changes: 6 additions & 2 deletions src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,19 @@ impl<'function> BlockContext<'function> {
crate::Span::default(),
)
}
super::BodyFragment::Loop { body, continuing } => {
super::BodyFragment::Loop {
body,
continuing,
break_if,
} => {
let body = lower_impl(blocks, bodies, body);
let continuing = lower_impl(blocks, bodies, continuing);

block.push(
crate::Statement::Loop {
body,
continuing,
break_if: None,
break_if,
},
crate::Span::default(),
)
Expand Down
115 changes: 101 additions & 14 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ enum BodyFragment {
Loop {
body: BodyIndex,
continuing: BodyIndex,
break_if: Option<Handle<crate::Expression>>,
},
Switch {
selector: Handle<crate::Expression>,
Expand Down Expand Up @@ -429,7 +430,7 @@ struct PhiExpression {
expressions: Vec<(spirv::Word, spirv::Word)>,
}

#[derive(Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum MergeBlockInformation {
LoopMerge,
LoopContinue,
Expand Down Expand Up @@ -3114,30 +3115,120 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
get_expr_handle!(condition_id, lexp)
};

let true_id = self.next()?;
let false_id = self.next()?;

let merge_info_for_true_id = ctx.mergers.get(&true_id).copied();
let merge_info_for_false_id = ctx.mergers.get(&false_id).copied();

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

// Handle `OpBranchConditional`s used at the end of a loop
// body's "continuing" section as a "conditional backedge",
// i.e. a `do`-`while` condition, or `break if` in WGSL.

// HACK(eddyb) this has to go to the parent *twice*, because
// `OpLoopMerge` left the "continuing" section nested in the
// loop body in terms of `parent`, but not `BodyFragment`.
let parent_body_idx = ctx.bodies[body_idx].parent;
let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent;
match ctx.bodies[parent_parent_body_idx].data[..] {
[.., BodyFragment::Loop {
body: loop_body_idx,
continuing: loop_continuing_idx,
break_if: ref mut break_if_slot @ None,
}] if body_idx == loop_continuing_idx => {
// HACK(eddyb) Naga doesn't seem to have this helper,
// so it's declared on the fly here for convenience.
#[derive(Copy, Clone)]
struct BranchTarget {
label_id: u32,
merge_info: Option<MergeBlockInformation>,
}
let true_target = BranchTarget {
label_id: true_id,
merge_info: merge_info_for_true_id,
};
let false_target = BranchTarget {
label_id: false_id,
merge_info: merge_info_for_false_id,
};

// Try both orderings of break-vs-backedge, because
// SPIR-V is symmetrical here, unlike WGSL `break if`.
let break_if_cond = [true, false].into_iter().find_map(|true_breaks| {
let (break_candidate, backedge_candidate) = if true_breaks {
(true_target, false_target)
} else {
(false_target, true_target)
};

if break_candidate.merge_info
!= Some(MergeBlockInformation::LoopMerge)
{
return None;
}

// HACK(eddyb) since Naga doesn't explicitly track
// backedges, this is checking for the outcome of
// `OpLoopMerge` below (even if it looks weird).
let backedge_candidate_is_backedge =
backedge_candidate.merge_info.is_none()
&& ctx.body_for_label.get(&backedge_candidate.label_id)
== Some(&loop_body_idx);
if !backedge_candidate_is_backedge {
return None;
}

Some(if true_breaks {
condition
} else {
ctx.expressions.append(
crate::Expression::Unary {
op: crate::UnaryOperator::Not,
expr: condition,
},
span,
)
})
});

if let Some(break_if_cond) = break_if_cond {
*break_if_slot = Some(break_if_cond);

// This `OpBranchConditional` ends the "continuing"
// section of the loop body as normal, with the
// `break if` condition having been stashed above.
break None;
}
}
_ => {}
}

block.extend(emitter.finish(ctx.expressions));
ctx.blocks.insert(block_id, block);
let body = &mut ctx.bodies[body_idx];
body.data.push(BodyFragment::BlockId(block_id));

let true_id = self.next()?;
let false_id = self.next()?;

let same_target = true_id == false_id;

// Start a body block for the `accept` branch.
let accept = ctx.bodies.len();
let mut accept_block = Body::with_parent(body_idx);

// If the `OpBranchConditional`target is somebody else's
// If the `OpBranchConditional` target is somebody else's
// merge or continue block, then put a `Break` or `Continue`
// statement in this new body block.
if let Some(info) = ctx.mergers.get(&true_id) {
if let Some(info) = merge_info_for_true_id {
merger(
match same_target {
true => &mut ctx.bodies[body_idx],
false => &mut accept_block,
},
info,
&info,
)
} else {
// Note the body index for the block we're branching to.
Expand All @@ -3161,8 +3252,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let reject = ctx.bodies.len();
let mut reject_block = Body::with_parent(body_idx);

if let Some(info) = ctx.mergers.get(&false_id) {
merger(&mut reject_block, info)
if let Some(info) = merge_info_for_false_id {
merger(&mut reject_block, &info)
} else {
let prev = ctx.body_for_label.insert(false_id, reject);
debug_assert!(prev.is_none());
Expand All @@ -3177,11 +3268,6 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
reject,
});

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

return Ok(());
}
Op::Switch => {
Expand Down Expand Up @@ -3351,6 +3437,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
parent_body.data.push(BodyFragment::Loop {
body: loop_body_idx,
continuing: continue_idx,
break_if: None,
});
body_idx = loop_body_idx;
}
Expand Down
Binary file added tests/in/spv/do-while.spv
Binary file not shown.
64 changes: 64 additions & 0 deletions tests/in/spv/do-while.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
;; Ensure that `do`-`while`-style loops, with conditional backedges, are properly
;; supported, via `break if` (as `continuing { ... if c { break; } }` is illegal).
;;
;; The SPIR-V below was compiled from this GLSL fragment shader:
;; ```glsl
;; #version 450
;;
;; void f(bool cond) {
;; do {} while(cond);
;; }
;;
;; void main() {
;; f(true);
;; }
;; ```

OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %f_b1_ "f(b1;"
OpName %cond "cond"
OpName %param "param"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bool = OpTypeBool
%_ptr_Function_bool = OpTypePointer Function %bool
%8 = OpTypeFunction %void %_ptr_Function_bool
%true = OpConstantTrue %bool

%main = OpFunction %void None %3
%5 = OpLabel
%param = OpVariable %_ptr_Function_bool Function
OpStore %param %true
%19 = OpFunctionCall %void %f_b1_ %param
OpReturn
OpFunctionEnd

%f_b1_ = OpFunction %void None %8
%cond = OpFunctionParameter %_ptr_Function_bool

%11 = OpLabel
OpBranch %12

%12 = OpLabel
OpLoopMerge %14 %15 None
OpBranch %13

%13 = OpLabel
OpBranch %15

;; This is the "continuing" block, and it contains a conditional branch between
;; the backedge (back to the loop header) and the loop merge ("break") target.
%15 = OpLabel
%16 = OpLoad %bool %cond
OpBranchConditional %16 %12 %14

%14 = OpLabel
OpReturn

OpFunctionEnd
33 changes: 33 additions & 0 deletions tests/out/glsl/do-while.main.Fragment.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#version 310 es

precision highp float;
precision highp int;


void fb1_(inout bool cond) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1() {
bool param = false;
param = true;
fb1_(param);
return;
}

void main() {
main_1();
}

31 changes: 31 additions & 0 deletions tests/out/hlsl/do-while.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

void fb1_(inout bool cond)
{
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _expr6 = cond;
bool unnamed = !(_expr6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1()
{
bool param = (bool)0;

param = true;
fb1_(param);
return;
}

void main()
{
main_1();
}
3 changes: 3 additions & 0 deletions tests/out/hlsl/do-while.hlsl.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vertex=()
fragment=(main:ps_5_1 )
compute=()
37 changes: 37 additions & 0 deletions tests/out/msl/do-while.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


void fb1_(
thread bool& cond
) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (!(cond)) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1(
) {
bool param = {};
param = true;
fb1_(param);
return;
}

fragment void main_(
) {
main_1();
}
24 changes: 24 additions & 0 deletions tests/out/wgsl/do-while.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
fn fb1_(cond: ptr<function, bool>) {
loop {
continue;
continuing {
let _e6 = (*cond);
_ = !(_e6);
break if !(_e6);
}
}
return;
}

fn main_1() {
var param: bool;

param = true;
fb1_((&param));
return;
}

@fragment
fn main() {
main_1();
}
Loading

0 comments on commit a006953

Please sign in to comment.