Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl-in] Automatic conversions for var initializers. #4755

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
- Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673).
- Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707).

- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)).
- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743), [#4755](https://github.com/gfx-rs/wgpu/pull/4755)).

Abstract types make numeric literals easier to use, by
automatically converting literals and other constant expressions
Expand All @@ -121,9 +121,10 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
Even though the literals are abstract integers, Naga recognizes
that it is safe and necessary to convert them to `f32` values in
order to build the vector. You can also use abstract values as
initializers for global constants, like this:
initializers for global constants and global and local variables,
like this:

const unit_x: vec2<f32> = vec2(1, 0);
var unit_x: vec2<f32> = vec2(1, 0);

The literals `1` and `0` are abstract integers, and the expression
`vec2(1, 0)` is an abstract vector. However, Naga recognizes that
Expand Down
98 changes: 61 additions & 37 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,30 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ast::GlobalDeclKind::Var(ref v) => {
let ty = self.resolve_ast_type(v.ty, &mut ctx)?;

let init = v
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.transpose()?;
let init;
if let Some(init_ast) = v.init {
let mut ectx = ctx.as_const();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
.try_automatic_conversions(lowered, &ty_res, v.name.span)
.map_err(|error| match error {
Error::AutoConversion {
dest_span: _,
dest_type,
source_span: _,
source_type,
} => Error::InitializationTypeMismatch {
name: v.name.span,
expected: dest_type,
got: source_type,
},
other => other,
})?;
init = Some(converted);
} else {
init = None;
}

let binding = if let Some(ref binding) = v.binding {
Some(crate::ResourceBinding {
Expand Down Expand Up @@ -1142,45 +1162,49 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(());
}
ast::LocalDecl::Var(ref v) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let initializer = match v.init {
Some(init) => Some(
self.expression(init, &mut ctx.as_expression(block, &mut emitter))?,
),
None => None,
};

let explicit_ty =
v.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
.transpose()?;

let ty = match (explicit_ty, initializer) {
(Some(explicit), Some(initializer)) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = resolve_inner!(ctx, initializer);
if !ctx.module.types[explicit]
.inner
.equivalent(initializer_ty, &ctx.module.types)
{
let gctx = &ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let mut ectx = ctx.as_expression(block, &mut emitter);

let ty;
let initializer;
match (v.init, explicit_ty) {
(Some(init), Some(explicit_ty)) => {
let init = self.expression_for_abstract(init, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
let init = ectx
.try_automatic_conversions(init, &ty_res, v.name.span)
.map_err(|error| match error {
Error::AutoConversion {
dest_span: _,
dest_type,
source_span: _,
source_type,
} => Error::InitializationTypeMismatch {
name: v.name.span,
expected: explicit.to_wgsl(gctx),
got: initializer_ty.to_wgsl(gctx),
});
}
explicit
expected: dest_type,
got: source_type,
},
other => other,
})?;
ty = explicit_ty;
initializer = Some(init);
}
(Some(explicit), None) => explicit,
(None, Some(initializer)) => ctx
.as_expression(block, &mut emitter)
.register_type(initializer)?,
(None, None) => {
return Err(Error::MissingType(v.name.span));
(Some(init), None) => {
let concretized = self.expression(init, &mut ectx)?;
ty = ectx.register_type(concretized)?;
initializer = Some(concretized);
}
};
(None, Some(explicit_ty)) => {
ty = explicit_ty;
initializer = None;
}
(None, None) => return Err(Error::MissingType(v.name.span)),
}

let (const_initializer, initializer) = {
match initializer {
Expand Down
120 changes: 120 additions & 0 deletions naga/tests/in/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// i/x: type inferred / explicit
// vX/mX/aX: vector / matrix / array of X
// where X: u/i/f: u32 / i32 / f32
// s: vector splat
// r: vector spread (vector arg to vector constructor)
// p: "partial" constructor (type parameter inferred)
// u/i/f/ai/af: u32 / i32 / f32 / abstract float / abstract integer as parameter
// _: just for alignment

// Ensure that:
// - the inferred type is correct.
// - all parameters' types are considered.
// - all parameters are converted to the consensus type.

var<private> xvipaiai: vec2<i32> = vec2(42, 43);
var<private> xvupaiai: vec2<u32> = vec2(44, 45);
var<private> xvfpaiai: vec2<f32> = vec2(46, 47);

var<private> xvupuai: vec2<u32> = vec2(42u, 43);
var<private> xvupaiu: vec2<u32> = vec2(42, 43u);

var<private> xvuuai: vec2<u32> = vec2<u32>(42u, 43);
var<private> xvuaiu: vec2<u32> = vec2<u32>(42, 43u);

var<private> xmfpaiaiaiai: mat2x2<f32> = mat2x2(1, 2, 3, 4);
var<private> xmfpafaiaiai: mat2x2<f32> = mat2x2(1.0, 2, 3, 4);
var<private> xmfpaiafaiai: mat2x2<f32> = mat2x2(1, 2.0, 3, 4);
var<private> xmfpaiaiafai: mat2x2<f32> = mat2x2(1, 2, 3.0, 4);
var<private> xmfpaiaiaiaf: mat2x2<f32> = mat2x2(1, 2, 3, 4.0);

var<private> xvispai: vec2<i32> = vec2(1);
var<private> xvfspaf: vec2<f32> = vec2(1.0);
var<private> xvis_ai: vec2<i32> = vec2<i32>(1);
var<private> xvus_ai: vec2<u32> = vec2<u32>(1);
var<private> xvfs_ai: vec2<f32> = vec2<f32>(1);
var<private> xvfs_af: vec2<f32> = vec2<f32>(1.0);

var<private> xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafaiai: array<f32, 2> = array<f32, 2>(1, 2);

var<private> xafpaiai: array<i32, 2> = array(1, 2);
var<private> xafpaiaf: array<f32, 2> = array(1, 2.0);
var<private> xafpafai: array<f32, 2> = array(1.0, 2);
var<private> xafpafaf: array<f32, 2> = array(1.0, 2.0);

fn all_constant_arguments() {
var xvipaiai: vec2<i32> = vec2(42, 43);
var xvupaiai: vec2<u32> = vec2(44, 45);
var xvfpaiai: vec2<f32> = vec2(46, 47);

var xvupuai: vec2<u32> = vec2(42u, 43);
var xvupaiu: vec2<u32> = vec2(42, 43u);

var xvuuai: vec2<u32> = vec2<u32>(42u, 43);
var xvuaiu: vec2<u32> = vec2<u32>(42, 43u);

var xmfpaiaiaiai: mat2x2<f32> = mat2x2(1, 2, 3, 4);
var xmfpafaiaiai: mat2x2<f32> = mat2x2(1.0, 2, 3, 4);
var xmfpaiafaiai: mat2x2<f32> = mat2x2(1, 2.0, 3, 4);
var xmfpaiaiafai: mat2x2<f32> = mat2x2(1, 2, 3.0, 4);
var xmfpaiaiaiaf: mat2x2<f32> = mat2x2(1, 2, 3, 4.0);

var xmfp_faiaiai: mat2x2<f32> = mat2x2(1.0f, 2, 3, 4);
var xmfpai_faiai: mat2x2<f32> = mat2x2(1, 2.0f, 3, 4);
var xmfpaiai_fai: mat2x2<f32> = mat2x2(1, 2, 3.0f, 4);
var xmfpaiaiai_f: mat2x2<f32> = mat2x2(1, 2, 3, 4.0f);

var xvispai: vec2<i32> = vec2(1);
var xvfspaf: vec2<f32> = vec2(1.0);
var xvis_ai: vec2<i32> = vec2<i32>(1);
var xvus_ai: vec2<u32> = vec2<u32>(1);
var xvfs_ai: vec2<f32> = vec2<f32>(1);
var xvfs_af: vec2<f32> = vec2<f32>(1.0);

var xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xaf_faf: array<f32, 2> = array<f32, 2>(1.0f, 2.0);
var xafaf_f: array<f32, 2> = array<f32, 2>(1.0, 2.0f);
var xafaiai: array<f32, 2> = array<f32, 2>(1, 2);
var xai_iai: array<i32, 2> = array<i32, 2>(1i, 2);
var xaiai_i: array<i32, 2> = array<i32, 2>(1, 2i);

// Ideally these would infer the var type from the initializer,
// but we don't support that yet.
var xaipaiai: array<i32, 2> = array(1, 2);
var xafpaiai: array<f32, 2> = array(1, 2);
var xafpaiaf: array<f32, 2> = array(1, 2.0);
var xafpafai: array<f32, 2> = array(1.0, 2);
var xafpafaf: array<f32, 2> = array(1.0, 2.0);
}

fn mixed_constant_and_runtime_arguments() {
var u: u32;
var i: i32;
var f: f32;

var xvupuai: vec2<u32> = vec2(u, 43);
var xvupaiu: vec2<u32> = vec2(42, u);

var xvuuai: vec2<u32> = vec2<u32>(u, 43);
var xvuaiu: vec2<u32> = vec2<u32>(42, u);

var xmfp_faiaiai: mat2x2<f32> = mat2x2(f, 2, 3, 4);
var xmfpai_faiai: mat2x2<f32> = mat2x2(1, f, 3, 4);
var xmfpaiai_fai: mat2x2<f32> = mat2x2(1, 2, f, 4);
var xmfpaiaiai_f: mat2x2<f32> = mat2x2(1, 2, 3, f);

var xaf_faf: array<f32, 2> = array<f32, 2>(f, 2.0);
var xafaf_f: array<f32, 2> = array<f32, 2>(1.0, f);
var xaf_fai: array<f32, 2> = array<f32, 2>(f, 2);
var xafai_f: array<f32, 2> = array<f32, 2>(1, f);
var xai_iai: array<i32, 2> = array<i32, 2>(i, 2);
var xaiai_i: array<i32, 2> = array<i32, 2>(1, i);

var xafp_faf: array<f32, 2> = array(f, 2.0);
var xafpaf_f: array<f32, 2> = array(1.0, f);
var xafp_fai: array<f32, 2> = array(f, 2);
var xafpai_f: array<f32, 2> = array(1, f);
var xaip_iai: array<i32, 2> = array(i, 2);
var xaipai_i: array<i32, 2> = array(1, i);
}
117 changes: 117 additions & 0 deletions naga/tests/out/msl/abstract-types-var.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_5 {
float inner[2];
};
struct type_7 {
int inner[2];
};

void all_constant_arguments(
) {
metal::int2 xvipaiai = metal::int2(42, 43);
metal::uint2 xvupaiai = metal::uint2(44u, 45u);
metal::float2 xvfpaiai = metal::float2(46.0, 47.0);
metal::uint2 xvupuai = metal::uint2(42u, 43u);
metal::uint2 xvupaiu = metal::uint2(42u, 43u);
metal::uint2 xvuuai = metal::uint2(42u, 43u);
metal::uint2 xvuaiu = metal::uint2(42u, 43u);
metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfp_faiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpai_faiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiai_fai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiai_f = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::int2 xvispai = metal::int2(1);
metal::float2 xvfspaf = metal::float2(1.0);
metal::int2 xvis_ai = metal::int2(1);
metal::uint2 xvus_ai = metal::uint2(1u);
metal::float2 xvfs_ai = metal::float2(1.0);
metal::float2 xvfs_af = metal::float2(1.0);
type_5 xafafaf = type_5 {1.0, 2.0};
type_5 xaf_faf = type_5 {1.0, 2.0};
type_5 xafaf_f = type_5 {1.0, 2.0};
type_5 xafaiai = type_5 {1.0, 2.0};
type_7 xai_iai = type_7 {1, 2};
type_7 xaiai_i = type_7 {1, 2};
type_7 xaipaiai = type_7 {1, 2};
type_5 xafpaiai = type_5 {1.0, 2.0};
type_5 xafpaiaf = type_5 {1.0, 2.0};
type_5 xafpafai = type_5 {1.0, 2.0};
type_5 xafpafaf = type_5 {1.0, 2.0};
}

void mixed_constant_and_runtime_arguments(
) {
uint u = {};
int i = {};
float f = {};
metal::uint2 xvupuai_1 = {};
metal::uint2 xvupaiu_1 = {};
metal::uint2 xvuuai_1 = {};
metal::uint2 xvuaiu_1 = {};
metal::float2x2 xmfp_faiaiai_1 = {};
metal::float2x2 xmfpai_faiai_1 = {};
metal::float2x2 xmfpaiai_fai_1 = {};
metal::float2x2 xmfpaiaiai_f_1 = {};
type_5 xaf_faf_1 = {};
type_5 xafaf_f_1 = {};
type_5 xaf_fai = {};
type_5 xafai_f = {};
type_7 xai_iai_1 = {};
type_7 xaiai_i_1 = {};
type_5 xafp_faf = {};
type_5 xafpaf_f = {};
type_5 xafp_fai = {};
type_5 xafpai_f = {};
type_7 xaip_iai = {};
type_7 xaipai_i = {};
uint _e3 = u;
xvupuai_1 = metal::uint2(_e3, 43u);
uint _e7 = u;
xvupaiu_1 = metal::uint2(42u, _e7);
uint _e11 = u;
xvuuai_1 = metal::uint2(_e11, 43u);
uint _e15 = u;
xvuaiu_1 = metal::uint2(42u, _e15);
float _e19 = f;
xmfp_faiaiai_1 = metal::float2x2(metal::float2(_e19, 2.0), metal::float2(3.0, 4.0));
float _e27 = f;
xmfpai_faiai_1 = metal::float2x2(metal::float2(1.0, _e27), metal::float2(3.0, 4.0));
float _e35 = f;
xmfpaiai_fai_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(_e35, 4.0));
float _e43 = f;
xmfpaiaiai_f_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, _e43));
float _e51 = f;
xaf_faf_1 = type_5 {_e51, 2.0};
float _e55 = f;
xafaf_f_1 = type_5 {1.0, _e55};
float _e59 = f;
xaf_fai = type_5 {_e59, 2.0};
float _e63 = f;
xafai_f = type_5 {1.0, _e63};
int _e67 = i;
xai_iai_1 = type_7 {_e67, 2};
int _e71 = i;
xaiai_i_1 = type_7 {1, _e71};
float _e75 = f;
xafp_faf = type_5 {_e75, 2.0};
float _e79 = f;
xafpaf_f = type_5 {1.0, _e79};
float _e83 = f;
xafp_fai = type_5 {_e83, 2.0};
float _e87 = f;
xafpai_f = type_5 {1.0, _e87};
int _e91 = i;
xaip_iai = type_7 {_e91, 2};
int _e95 = i;
xaipai_i = type_7 {1, _e95};
return;
}
Loading
Loading