Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl-in] Apply automatic conversions to values being assigned. #6822

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
}))
};

let expr_scalar = match expr_inner.scalar() {
let expr_scalar = match expr_inner.automatically_convertible_scalar(&self.module.types) {
Some(scalar) => scalar,
None => return Err(make_error()),
};
Expand Down Expand Up @@ -436,6 +436,32 @@ impl crate::TypeInner {
| Ti::BindingArray { .. } => None,
}
}

/// Return the leaf scalar type of `pointer`.
///
/// `pointer` must be a `TypeInner` representing a pointer type.
pub fn pointer_automatically_convertible_scalar(
&self,
types: &crate::UniqueArena<crate::Type>,
) -> Option<crate::Scalar> {
use crate::TypeInner as Ti;
match *self {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => {
Some(scalar)
}
Ti::Atomic(_) => None,
Ti::Pointer { base, .. } | Ti::Array { base, .. } => {
types[base].inner.automatically_convertible_scalar(types)
}
Ti::ValuePointer { scalar, .. } => Some(scalar),
Ti::Struct { .. }
| Ti::Image { .. }
| Ti::Sampler { .. }
| Ti::AccelerationStructure
| Ti::RayQuery
| Ti::BindingArray { .. } => None,
}
}
}

impl crate::Scalar {
Expand Down
41 changes: 29 additions & 12 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1700,31 +1700,48 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let target_span = ctx.ast_expressions.get_span(ast_target);

let target = self.expression_for_reference(
ast_target,
&mut ctx.as_expression(block, &mut emitter),
)?;
let mut value =
self.expression(value, &mut ctx.as_expression(block, &mut emitter))?;

let mut ectx = ctx.as_expression(block, &mut emitter);
let target = self.expression_for_reference(ast_target, &mut ectx)?;
let target_handle = match target {
Typed::Reference(handle) => handle,
Typed::Plain(handle) => {
let ty = ctx.invalid_assignment_type(handle);
return Err(Error::InvalidAssignment {
span: ctx.ast_expressions.get_span(ast_target),
span: target_span,
ty,
});
}
};

// Usually the value needs to be converted to match the type of
// the memory view you're assigning it to. The bit shift
// operators are exceptions, in that the right operand is always
// a `u32` or `vecN<u32>`.
let target_scalar = match op {
Some(crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight) => {
Some(crate::Scalar::U32)
}
_ => resolve_inner!(ectx, target_handle)
.pointer_automatically_convertible_scalar(&ectx.module.types),
};

let value = self.expression_for_abstract(value, &mut ectx)?;
let mut value = match target_scalar {
Some(target_scalar) => ectx.try_automatic_conversion_for_leaf_scalar(
value,
target_scalar,
target_span,
)?,
None => value,
};

let value = match op {
Some(op) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let mut left = ctx.apply_load_rule(target)?;
ctx.binary_op_splat(op, &mut left, &mut value)?;
ctx.append_expression(
let mut left = ectx.apply_load_rule(target)?;
ectx.binary_op_splat(op, &mut left, &mut value)?;
ectx.append_expression(
crate::Expression::Binary {
op,
left,
Expand Down
81 changes: 79 additions & 2 deletions naga/tests/in/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var<private> ivfs_af = vec2<f32>(1.0);
var<private> iafafaf = array<f32, 2>(1.0, 2.0);
var<private> iafaiai = array<f32, 2>(1, 2);

var<private> iaipaiai = array(1, 2);
var<private> iafpafaf = array(1.0, 2.0);
var<private> iafpaiaf = array(1, 2.0);
var<private> iafpafai = array(1.0, 2);
Expand Down Expand Up @@ -93,13 +94,63 @@ fn all_constant_arguments() {
var xai_iai: array<i32, 2> = array<i32, 2>(1i, 2);
var xaiai_i: array<i32, 2> = array<i32, 2>(1, 2i);

// Ideally these would infer the var type from the initializer,
// but we don't support that yet.
var xaipaiai: array<i32, 2> = array(1, 2);
var xafpaiai: array<f32, 2> = array(1, 2);
var xafpaiaf: array<f32, 2> = array(1, 2.0);
var xafpafai: array<f32, 2> = array(1.0, 2);
var xafpafaf: array<f32, 2> = array(1.0, 2.0);

var iaipaiai = array(1, 2);
var iafpaiaf = array(1, 2.0);
var iafpafai = array(1.0, 2);
var iafpafaf = array(1.0, 2.0);

// Assignments to all of the above.
xvipaiai = vec2(42, 43);
xvupaiai = vec2(44, 45);
xvfpaiai = vec2(46, 47);

xvupuai = vec2(42u, 43);
xvupaiu = vec2(42, 43u);

xvuuai = vec2<u32>(42u, 43);
xvuaiu = vec2<u32>(42, 43u);

xmfpaiaiaiai = mat2x2(1, 2, 3, 4);
xmfpafaiaiai = mat2x2(1.0, 2, 3, 4);
xmfpaiafaiai = mat2x2(1, 2.0, 3, 4);
xmfpaiaiafai = mat2x2(1, 2, 3.0, 4);
xmfpaiaiaiaf = mat2x2(1, 2, 3, 4.0);

xmfp_faiaiai = mat2x2(1.0f, 2, 3, 4);
xmfpai_faiai = mat2x2(1, 2.0f, 3, 4);
xmfpaiai_fai = mat2x2(1, 2, 3.0f, 4);
xmfpaiaiai_f = mat2x2(1, 2, 3, 4.0f);

xvispai = vec2(1);
xvfspaf = vec2(1.0);
xvis_ai = vec2<i32>(1);
xvus_ai = vec2<u32>(1);
xvfs_ai = vec2<f32>(1);
xvfs_af = vec2<f32>(1.0);

xafafaf = array<f32, 2>(1.0, 2.0);
xaf_faf = array<f32, 2>(1.0f, 2.0);
xafaf_f = array<f32, 2>(1.0, 2.0f);
xafaiai = array<f32, 2>(1, 2);
xai_iai = array<i32, 2>(1i, 2);
xaiai_i = array<i32, 2>(1, 2i);

xaipaiai = array(1, 2);
xafpaiai = array(1, 2);
xafpaiaf = array(1, 2.0);
xafpafai = array(1.0, 2);
xafpafaf = array(1.0, 2.0);

iaipaiai = array(1, 2);
iafpaiaf = array(1, 2.0);
iafpafai = array(1.0, 2);
iafpafaf = array(1.0, 2.0);
}

fn mixed_constant_and_runtime_arguments() {
Expand Down Expand Up @@ -131,4 +182,30 @@ fn mixed_constant_and_runtime_arguments() {
var xafpai_f: array<f32, 2> = array(1, f);
var xaip_iai: array<i32, 2> = array(i, 2);
var xaipai_i: array<i32, 2> = array(1, i);

// Assignments to all of the above.
xvupuai = vec2(u, 43);
xvupaiu = vec2(42, u);

xvuuai = vec2<u32>(u, 43);
xvuaiu = vec2<u32>(42, u);

xmfp_faiaiai = mat2x2(f, 2, 3, 4);
xmfpai_faiai = mat2x2(1, f, 3, 4);
xmfpaiai_fai = mat2x2(1, 2, f, 4);
xmfpaiaiai_f = mat2x2(1, 2, 3, f);

xaf_faf = array<f32, 2>(f, 2.0);
xafaf_f = array<f32, 2>(1.0, f);
xaf_fai = array<f32, 2>(f, 2);
xafai_f = array<f32, 2>(1, f);
xai_iai = array<i32, 2>(i, 2);
xaiai_i = array<i32, 2>(1, i);

xafp_faf = array(f, 2.0);
xafpaf_f = array(1.0, f);
xafp_fai = array(f, 2);
xafpai_f = array(1, f);
xaip_iai = array(i, 2);
xaipai_i = array(1, i);
}
82 changes: 82 additions & 0 deletions naga/tests/out/msl/abstract-types-var.msl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,48 @@ void all_constant_arguments(
type_7 xafpaiaf = type_7 {1.0, 2.0};
type_7 xafpafai = type_7 {1.0, 2.0};
type_7 xafpafaf = type_7 {1.0, 2.0};
type_8 iaipaiai = type_8 {1, 2};
type_7 iafpaiaf = type_7 {1.0, 2.0};
type_7 iafpafai = type_7 {1.0, 2.0};
type_7 iafpafaf = type_7 {1.0, 2.0};
xvipaiai = metal::int2(42, 43);
xvupaiai = metal::uint2(44u, 45u);
xvfpaiai = metal::float2(46.0, 47.0);
xvupuai = metal::uint2(42u, 43u);
xvupaiu = metal::uint2(42u, 43u);
xvuuai = metal::uint2(42u, 43u);
xvuaiu = metal::uint2(42u, 43u);
xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfp_faiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpai_faiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpaiai_fai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xmfpaiaiai_f = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
xvispai = metal::int2(1);
xvfspaf = metal::float2(1.0);
xvis_ai = metal::int2(1);
xvus_ai = metal::uint2(1u);
xvfs_ai = metal::float2(1.0);
xvfs_af = metal::float2(1.0);
xafafaf = type_7 {1.0, 2.0};
xaf_faf = type_7 {1.0, 2.0};
xafaf_f = type_7 {1.0, 2.0};
xafaiai = type_7 {1.0, 2.0};
xai_iai = type_8 {1, 2};
xaiai_i = type_8 {1, 2};
xaipaiai = type_8 {1, 2};
xafpaiai = type_7 {1.0, 2.0};
xafpaiaf = type_7 {1.0, 2.0};
xafpafai = type_7 {1.0, 2.0};
xafpafaf = type_7 {1.0, 2.0};
iaipaiai = type_8 {1, 2};
iafpaiaf = type_7 {1.0, 2.0};
iafpafai = type_7 {1.0, 2.0};
iafpafaf = type_7 {1.0, 2.0};
return;
}

void mixed_constant_and_runtime_arguments(
Expand Down Expand Up @@ -113,5 +155,45 @@ void mixed_constant_and_runtime_arguments(
xaip_iai = type_8 {_e91, 2};
int _e95 = i;
xaipai_i = type_8 {1, _e95};
uint _e99 = u;
xvupuai_1 = metal::uint2(_e99, 43u);
uint _e102 = u;
xvupaiu_1 = metal::uint2(42u, _e102);
uint _e105 = u;
xvuuai_1 = metal::uint2(_e105, 43u);
uint _e108 = u;
xvuaiu_1 = metal::uint2(42u, _e108);
float _e111 = f;
xmfp_faiaiai_1 = metal::float2x2(metal::float2(_e111, 2.0), metal::float2(3.0, 4.0));
float _e118 = f;
xmfpai_faiai_1 = metal::float2x2(metal::float2(1.0, _e118), metal::float2(3.0, 4.0));
float _e125 = f;
xmfpaiai_fai_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(_e125, 4.0));
float _e132 = f;
xmfpaiaiai_f_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, _e132));
float _e139 = f;
xaf_faf_1 = type_7 {_e139, 2.0};
float _e142 = f;
xafaf_f_1 = type_7 {1.0, _e142};
float _e145 = f;
xaf_fai = type_7 {_e145, 2.0};
float _e148 = f;
xafai_f = type_7 {1.0, _e148};
int _e151 = i;
xai_iai_1 = type_8 {_e151, 2};
int _e154 = i;
xaiai_i_1 = type_8 {1, _e154};
float _e157 = f;
xafp_faf = type_7 {_e157, 2.0};
float _e160 = f;
xafpaf_f = type_7 {1.0, _e160};
float _e163 = f;
xafp_fai = type_7 {_e163, 2.0};
float _e166 = f;
xafpai_f = type_7 {1.0, _e166};
int _e169 = i;
xaip_iai = type_8 {_e169, 2};
int _e172 = i;
xaipai_i = type_8 {1, _e172};
return;
}
Loading
Loading