Skip to content

Commit

Permalink
[naga wgsl-in] Support abstract operands to binary operators.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Dec 13, 2023
1 parent f2828ac commit c4b4387
Show file tree
Hide file tree
Showing 11 changed files with 498 additions and 28 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ This feature allowed you to call `global_id` on any wgpu opaque handle to get a

#### Naga

- Naga'sn WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).
- Naga's WGSL front end now allows binary operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850).

- Naga's WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).

- Naga constant evaluation can now process binary operators whose operands are both vectors. By @jimblandy in [#4861](https://github.com/gfx-rs/wgpu/pull/4861).

### Changes
Expand Down
20 changes: 20 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ pub enum Error<'a> {
source_span: Span,
source_type: String,
},
AutoConversionLeafScalar {
dest_span: Span,
dest_scalar: String,
source_span: Span,
source_type: String,
},
ConcretizationFailed {
expr_span: Span,
expr_type: String,
Expand Down Expand Up @@ -738,6 +744,20 @@ impl<'a> Error<'a> {
],
notes: vec![],
},
Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError {
message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"),
labels: vec![
(
dest_span,
format!("a value with elements of type {dest_scalar} is required here").into(),
),
(
source_span,
format!("this expression has type {source_type}").into(),
)
],
notes: vec![],
},
Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError {
message: format!("failed to convert expression to a concrete type: {}", inner),
labels: vec![
Expand Down
80 changes: 72 additions & 8 deletions naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,80 @@ impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> {
}
};

let converted = if let crate::TypeInner::Array { .. } = *goal_inner {
let span = self.get_expression_span(expr);
self.convert_leaf_scalar(expr, expr_span, goal_scalar)
}

/// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions.
///
/// If no conversions are necessary, return `expr` unchanged.
///
/// If automatic conversions cannot convert `expr` to `goal_scalar`, return
/// an [`AutoConversionLeafScalar`] error.
///
/// Although the Load Rule is one of the automatic conversions, this
/// function assumes it has already been applied if appropriate, as
/// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
///
/// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar
pub fn try_automatic_conversion_for_leaf_scalar(
&mut self,
expr: Handle<crate::Expression>,
goal_scalar: crate::Scalar,
goal_span: Span,
) -> Result<Handle<crate::Expression>, super::Error<'source>> {
let expr_span = self.get_expression_span(expr);
let expr_resolution = super::resolve!(self, expr);
let types = &self.module.types;
let expr_inner = expr_resolution.inner_with(types);

let make_error = || {
let gctx = &self.module.to_ctx();
let source_type = expr_resolution.to_wgsl(gctx);
super::Error::AutoConversionLeafScalar {
dest_span: goal_span,
dest_scalar: goal_scalar.to_wgsl(),
source_span: expr_span,
source_type,
}
};

let expr_scalar = match expr_inner.scalar() {
Some(scalar) => scalar,
None => return Err(make_error()),
};

if expr_scalar == goal_scalar {
return Ok(expr);
}

if !expr_scalar.automatically_converts_to(goal_scalar) {
return Err(make_error());
}

assert!(expr_scalar.is_abstract());

self.convert_leaf_scalar(expr, expr_span, goal_scalar)
}

fn convert_leaf_scalar(
&mut self,
expr: Handle<crate::Expression>,
expr_span: Span,
goal_scalar: crate::Scalar,
) -> Result<Handle<crate::Expression>, super::Error<'source>> {
let expr_inner = super::resolve_inner!(self, expr);
if let crate::TypeInner::Array { .. } = *expr_inner {
self.as_const_evaluator()
.cast_array(expr, goal_scalar, span)
.map_err(|err| super::Error::ConstantEvaluatorError(err, span))?
.cast_array(expr, goal_scalar, expr_span)
.map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span))
} else {
let cast = crate::Expression::As {
expr,
kind: goal_scalar.kind,
convert: Some(goal_scalar.width),
};
self.append_expression(cast, expr_span)?
};

Ok(converted)
self.append_expression(cast, expr_span)
}
}

/// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions.
Expand Down Expand Up @@ -428,6 +487,11 @@ impl crate::Scalar {
}
}

/// Return `true` if automatic conversions will covert `self` to `goal`.
pub fn automatically_converts_to(self, goal: Self) -> bool {
self.automatic_conversion_combine(goal) == Some(goal)
}

const fn concretize(self) -> Self {
use crate::ScalarKind as Sk;
match self.kind {
Expand Down
52 changes: 47 additions & 5 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1602,11 +1602,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Typed::Reference(pointer));
}
ast::Expression::Binary { op, left, right } => {
// Load both operands.
let mut left = self.expression(left, ctx)?;
let mut right = self.expression(right, ctx)?;
ctx.binary_op_splat(op, &mut left, &mut right)?;
Typed::Plain(crate::Expression::Binary { op, left, right })
self.binary(op, left, right, span, ctx)?
}
ast::Expression::Call {
ref function,
Expand Down Expand Up @@ -1737,6 +1733,52 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
expr.try_map(|handle| ctx.append_expression(handle, span))
}

fn binary(
&mut self,
op: crate::BinaryOperator,
left: Handle<ast::Expression<'source>>,
right: Handle<ast::Expression<'source>>,
span: Span,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Typed<crate::Expression>, Error<'source>> {
// Load both operands.
let mut left = self.expression_for_abstract(left, ctx)?;
let mut right = self.expression_for_abstract(right, ctx)?;

// Convert `scalar op vector` to `vector op vector` by introducing
// `Splat` expressions.
ctx.binary_op_splat(op, &mut left, &mut right)?;

// Apply automatic conversions.
match op {
// Shift operators require the right operand to be `u32` or
// `vecN<u32>`. We can let the validator sort out vector length
// issues, but the right operand must be, or convert to, a u32 leaf
// scalar.
crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => {
right =
ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?;
}

// All other operators follow the same pattern: reconcile the
// scalar leaf types. If there's no reconciliation possible,
// leave the expressions as they are: validation will report the
// problem.
_ => {
ctx.grow_types(left)?;
ctx.grow_types(right)?;
if let Ok(consensus_scalar) =
ctx.automatic_conversion_consensus([left, right].iter())
{
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
}
}
}

Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
}

/// Generate Naga IR for call expressions and statements, and type
/// constructor expressions.
///
Expand Down
54 changes: 47 additions & 7 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub enum ConstantEvaluatorError {
InvalidAccessIndexTy,
#[error("Constants don't support array length expressions")]
ArrayLength,
#[error("Cannot cast type `{from}` to `{to}`")]
#[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
InvalidCastArg { from: String, to: String },
#[error("Cannot apply the unary op to the argument")]
InvalidUnaryOpArg,
Expand Down Expand Up @@ -989,15 +989,11 @@ impl<'a> ConstantEvaluator<'a> {
let expr = self.eval_zero_value(expr, span)?;

let make_error = || -> Result<_, ConstantEvaluatorError> {
let ty = self.resolve_type(expr)?;
let from = format!("{:?} {:?}", expr, self.expressions[expr]);

#[cfg(feature = "wgsl-in")]
let from = ty.to_wgsl(&self.to_ctx());
#[cfg(feature = "wgsl-in")]
let to = target.to_wgsl();

#[cfg(not(feature = "wgsl-in"))]
let from = format!("{ty:?}");
#[cfg(not(feature = "wgsl-in"))]
let to = format!("{target:?}");

Expand Down Expand Up @@ -1325,6 +1321,47 @@ impl<'a> ConstantEvaluator<'a> {
BinaryOperator::Modulo => a % b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
Literal::AbstractInt(match op {
BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
ConstantEvaluatorError::Overflow("addition".into())
})?,
BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
ConstantEvaluatorError::Overflow("subtraction".into())
})?,
BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
ConstantEvaluatorError::Overflow("multiplication".into())
})?,
BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
if b == 0 {
ConstantEvaluatorError::DivisionByZero
} else {
ConstantEvaluatorError::Overflow("division".into())
}
})?,
BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
if b == 0 {
ConstantEvaluatorError::RemainderByZero
} else {
ConstantEvaluatorError::Overflow("remainder".into())
}
})?,
BinaryOperator::And => a & b,
BinaryOperator::ExclusiveOr => a ^ b,
BinaryOperator::InclusiveOr => a | b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
})
}
(Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
Literal::AbstractFloat(match op {
BinaryOperator::Add => a + b,
BinaryOperator::Subtract => a - b,
BinaryOperator::Multiply => a * b,
BinaryOperator::Divide => a / b,
BinaryOperator::Modulo => a % b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
})
}
(Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
BinaryOperator::LogicalAnd => a && b,
BinaryOperator::LogicalOr => a || b,
Expand Down Expand Up @@ -1550,7 +1587,10 @@ impl<'a> ConstantEvaluator<'a> {
};
Tr::Value(TypeInner::Vector { scalar, size })
}
_ => return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant),
_ => {
log::debug!("resolve_type: SubexpressionsAreNotConstant");
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
}
};

Ok(resolution)
Expand Down
45 changes: 45 additions & 0 deletions naga/tests/in/abstract-types-operators.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
const plus_fafaf: f32 = 1.0 + 2.0;
const plus_fafai: f32 = 1.0 + 2;
const plus_faf_f: f32 = 1.0 + 2f;
const plus_faiaf: f32 = 1 + 2.0;
const plus_faiai: f32 = 1 + 2;
const plus_fai_f: f32 = 1 + 2f;
const plus_f_faf: f32 = 1f + 2.0;
const plus_f_fai: f32 = 1f + 2;
const plus_f_f_f: f32 = 1f + 2f;

const plus_iaiai: i32 = 1 + 2;
const plus_iai_i: i32 = 1 + 2i;
const plus_i_iai: i32 = 1i + 2;
const plus_i_i_i: i32 = 1i + 2i;

const plus_uaiai: u32 = 1 + 2;
const plus_uai_u: u32 = 1 + 2u;
const plus_u_uai: u32 = 1u + 2;
const plus_u_u_u: u32 = 1u + 2u;

fn runtime_values() {
var f: f32 = 42;
var i: i32 = 43;
var u: u32 = 44;

var plus_fafaf: f32 = 1.0 + 2.0;
var plus_fafai: f32 = 1.0 + 2;
var plus_faf_f: f32 = 1.0 + f;
var plus_faiaf: f32 = 1 + 2.0;
var plus_faiai: f32 = 1 + 2;
var plus_fai_f: f32 = 1 + f;
var plus_f_faf: f32 = f + 2.0;
var plus_f_fai: f32 = f + 2;
var plus_f_f_f: f32 = f + f;

var plus_iaiai: i32 = 1 + 2;
var plus_iai_i: i32 = 1 + i;
var plus_i_iai: i32 = i + 2;
var plus_i_i_i: i32 = i + i;

var plus_uaiai: u32 = 1 + 2;
var plus_uai_u: u32 = 1 + u;
var plus_u_uai: u32 = u + 2;
var plus_u_u_u: u32 = u + u;
}
Loading

0 comments on commit c4b4387

Please sign in to comment.