Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement override-expression evaluation in functions #5387

Merged
merged 12 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
387 changes: 366 additions & 21 deletions naga/src/back/pipeline_constants.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions naga/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ impl Block {
self.span_info.splice(range.clone(), other.span_info);
self.body.splice(range, other.body);
}

pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> {
let Block { body, span_info } = self;
body.into_iter().zip(span_info)
}

pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
let span_iter = self.span_info.iter();
self.body.iter().zip(span_iter)
Expand Down
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let init;
if let Some(init_ast) = v.init {
let mut ectx = ctx.as_const();
let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
Expand Down
75 changes: 46 additions & 29 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ enum Behavior<'a> {
Glsl(GlslRestrictions<'a>),
}

impl Behavior<'_> {
/// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
const fn has_runtime_restrictions(&self) -> bool {
matches!(
self,
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_))
)
}
}

/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
Expand Down Expand Up @@ -699,37 +710,43 @@ impl<'a> ConstantEvaluator<'a> {
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
match (
&self.behavior,
self.expression_kind_tracker.type_of_with_expr(&expr),
) {
// avoid errors on unimplemented functionality if possible
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Const,
) => match self.try_eval_and_append_impl(&expr, span) {
Err(
ConstantEvaluatorError::NotImplemented(_)
| ConstantEvaluatorError::InvalidBinaryOpArgs,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
res => res,
match self.expression_kind_tracker.type_of_with_expr(&expr) {
ExpressionKind::Const => {
let eval_result = self.try_eval_and_append_impl(&expr, span);
// We should be able to evaluate `Const` expressions at this
// point. If we failed to, then that probably means we just
// haven't implemented that part of constant evaluation. Work
// around this by simply emitting it as a run-time expression.
if self.behavior.has_runtime_restrictions()
&& matches!(
eval_result,
Err(ConstantEvaluatorError::NotImplemented(_)
| ConstantEvaluatorError::InvalidBinaryOpArgs,)
)
{
Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
} else {
eval_result
}
}
ExpressionKind::Override => match self.behavior {
Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
Ok(self.append_expr(expr, span, ExpressionKind::Override))
}
Behavior::Wgsl(WgslRestrictions::Const) => {
Err(ConstantEvaluatorError::OverrideExpr)
}
Behavior::Glsl(_) => {
unreachable!()
}
},
(_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span),
(&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => {
Err(ConstantEvaluatorError::OverrideExpr)
ExpressionKind::Runtime => {
if self.behavior.has_runtime_restrictions() {
Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
} else {
Err(ConstantEvaluatorError::RuntimeExpr)
}
}
(
&Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)),
ExpressionKind::Override,
) => Ok(self.append_expr(expr, span, ExpressionKind::Override)),
(&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(),
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Runtime,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
(_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr),
}
}

Expand Down
4 changes: 2 additions & 2 deletions naga/src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub enum GlobalVariableError {
Handle<crate::Type>,
#[source] Disalignment,
),
#[error("Initializer must be a const-expression")]
#[error("Initializer must be an override-expression")]
InitializerExprType,
#[error("Initializer doesn't match the variable type")]
InitializerType,
Expand Down Expand Up @@ -529,7 +529,7 @@ impl super::Validator {
}
}

if !global_expr_kind.is_const(init) {
if !global_expr_kind.is_const_or_override(init) {
return Err(GlobalVariableError::InitializerExprType);
}

Expand Down
9 changes: 9 additions & 0 deletions naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"o": 2.0
}
)
7 changes: 7 additions & 0 deletions naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
override o: i32;
var<workgroup> a: atomic<u32>;

@compute @workgroup_size(1)
fn f() {
atomicCompareExchangeWeak(&a, u32(o), 1u);
}
18 changes: 18 additions & 0 deletions naga/tests/in/overrides-ray-query.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
(
god_mode: true,
spv: (
version: (1, 4),
separate_entry_points: true,
),
msl: (
lang_version: (2, 4),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: false,
per_entry_point_map: {},
inline_samplers: [],
),
pipeline_constants: {
"o": 2.0
}
)
21 changes: 21 additions & 0 deletions naga/tests/in/overrides-ray-query.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
override o: f32;

@group(0) @binding(0)
var acc_struct: acceleration_structure;

@compute @workgroup_size(1)
fn main() {
var rq: ray_query;

let desc = RayDesc(
RAY_FLAG_TERMINATE_ON_FIRST_HIT,
0xFFu,
o * 17.0,
o * 19.0,
vec3<f32>(o * 23.0),
vec3<f32>(o * 29.0, o * 31.0, o * 37.0),
);
rayQueryInitialize(&rq, acc_struct, desc);

while (rayQueryProceed(&rq)) {}
}
10 changes: 9 additions & 1 deletion naga/tests/in/overrides.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@

override inferred_f32 = 2.718;

var<private> gain_x_10: f32 = gain * 10.;

@compute @workgroup_size(1)
fn main() {}
fn main() {
var t = height * 5;
let a = !has_point_light;
var x = a;

var gain_x_100 = gain_x_10 * 10.;
}
148 changes: 146 additions & 2 deletions naga/tests/out/analysis/overrides.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,143 @@
),
may_kill: false,
sampling_set: [],
global_uses: [],
expressions: [],
global_uses: [
("READ"),
],
expressions: [
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(2),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(4),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 1,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 2,
space: Private,
)),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(2),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(12),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
],
sampling: [],
dual_source_blending: false,
),
Expand Down Expand Up @@ -43,5 +178,14 @@
kind: Float,
width: 4,
))),
Handle(2),
Value(Scalar((
kind: Float,
width: 4,
))),
Value(Scalar((
kind: Float,
width: 4,
))),
],
)
10 changes: 10 additions & 0 deletions naga/tests/out/hlsl/overrides.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@ static const float depth = 2.3;
static const float height = 4.6;
static const float inferred_f32_ = 2.718;

static float gain_x_10_ = 11.0;

[numthreads(1, 1, 1)]
void main()
{
float t = (float)0;
bool x = (bool)0;
float gain_x_100_ = (float)0;

t = 23.0;
x = true;
float _expr10 = gain_x_10_;
gain_x_100_ = (_expr10 * 10.0);
return;
}
Loading
Loading