diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index e2664e13bf..7301c4f06b 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -765,7 +765,7 @@ impl Writer { write!(self.out, "[{}]", index)?; } crate::TypeInner::Array { .. } => { - write!(self.out, "[{}]", index)?; + write!(self.out, ".{}[{}]", WRAPPED_ARRAY_FIELD, index)?; } _ => { // unexpected indexing, should fail validation diff --git a/src/front/wgsl/lexer.rs b/src/front/wgsl/lexer.rs index b7b34d2316..015b7b99d7 100644 --- a/src/front/wgsl/lexer.rs +++ b/src/front/wgsl/lexer.rs @@ -233,15 +233,23 @@ impl<'a> Lexer<'a> { token } - pub(super) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + pub(super) fn expect_span( + &mut self, + expected: Token<'a>, + ) -> Result, Error<'a>> { let next = self.next(); if next.0 == expected { - Ok(()) + Ok(next.1) } else { Err(Error::Unexpected(next, ExpectedToken::Token(expected))) } } + pub(super) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + self.expect_span(expected)?; + Ok(()) + } + pub(super) fn expect_generic_paren(&mut self, expected: char) -> Result<(), Error<'a>> { let next = self.next_generic(); if next.0 == Token::Paren(expected) { diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 65abcd01d1..cee4cab0a3 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -12,7 +12,7 @@ use crate::{ proc::{ ensure_block_returns, Alignment, Layouter, ResolveContext, ResolveError, TypeResolution, }, - FastHashMap, + ConstantInner, FastHashMap, ScalarValue, }; use self::lexer::Lexer; @@ -26,6 +26,7 @@ use codespan_reporting::{ }; use std::{ borrow::Cow, + convert::TryFrom, io::{self, Write}, iter, num::{NonZeroU32, ParseFloatError, ParseIntError}, @@ -98,6 +99,8 @@ pub enum Error<'a> { #[error("")] BadFloat(Span, ParseFloatError), #[error("")] + BadU32Constant(Span), + #[error("")] BadScalarWidth(Span, &'a str), #[error("")] BadAccessor(Span), @@ -250,6 +253,15 @@ impl<'a> Error<'a> { labels: vec![(bad_span.clone(), "expected floating-point literal".into())], notes: vec![err.to_string()], }, + Error::BadU32Constant(ref bad_span) => ParseError { + message: format!( + "expected non-negative integer constant expression, found `{}`", + &source[bad_span.clone()], + ), + labels: vec![(bad_span.clone(), "expected non-negative integer".into())], + notes: vec![], + }, + Error::BadScalarWidth(ref bad_span, width) => ParseError { message: format!("invalid width of `{}` for literal", width,), labels: vec![(bad_span.clone(), "invalid width".into())], @@ -1023,7 +1035,7 @@ impl Parser { ty: char, width: &'a str, token: TokenSpan<'a>, - ) -> Result> { + ) -> Result> { let span = token.1; let value = match ty { 'i' => word @@ -1737,12 +1749,34 @@ impl Parser { } } Token::Paren('[') => { - let _ = lexer.next(); + let (_, open_brace_span) = lexer.next(); let index = self.parse_general_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Paren(']'))?; - crate::Expression::Access { - base: handle, - index, + let close_brace_span = lexer.expect_span(Token::Paren(']'))?; + + if let crate::Expression::Constant(constant) = ctx.expressions[index] { + let expr_span = open_brace_span.end..close_brace_span.start; + + let index = match ctx.constants[constant].inner { + ConstantInner::Scalar { + value: ScalarValue::Uint(int), + .. + } => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)), + ConstantInner::Scalar { + value: ScalarValue::Sint(int), + .. + } => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)), + _ => Err(Error::BadU32Constant(expr_span)), + }?; + + crate::Expression::AccessIndex { + base: handle, + index, + } + } else { + crate::Expression::Access { + base: handle, + index, + } } } _ => { diff --git a/tests/out/access.msl b/tests/out/access.msl index 4ce5bb71e3..54f422f4da 100644 --- a/tests/out/access.msl +++ b/tests/out/access.msl @@ -28,7 +28,7 @@ vertex fooOutput foo( type6 c; float baz = foo1; foo1 = 1.0; - metal::float4 _e9 = bar.matrix[3u]; + metal::float4 _e9 = bar.matrix[3]; float b = _e9.x; int a = bar.data[(1 + (_buffer_sizes.size0 - 64 - 4) / 4) - 1u]; for(int _i=0; _i<5; ++_i) c.inner[_i] = type6 {a, static_cast(b), 3, 4, 5}.inner[_i]; diff --git a/tests/out/access.wgsl b/tests/out/access.wgsl index f05072bd7b..8bb52234ad 100644 --- a/tests/out/access.wgsl +++ b/tests/out/access.wgsl @@ -14,7 +14,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { let baz: f32 = foo1; foo1 = 1.0; - let _e9: vec4 = bar.matrix[3u]; + let _e9: vec4 = bar.matrix[3]; let b: f32 = _e9.x; let a: i32 = bar.data[(arrayLength(&bar.data) - 1u)]; c = array(a, i32(b), 3, 4, 5); diff --git a/tests/out/globals.spvasm b/tests/out/globals.spvasm index 2668d178d0..0e85aa5035 100644 --- a/tests/out/globals.spvasm +++ b/tests/out/globals.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 20 +; Bound: 21 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -22,11 +22,12 @@ OpDecorate %11 ArrayStride 4 %12 = OpVariable %13 Workgroup %16 = OpTypeFunction %2 %18 = OpTypePointer Workgroup %10 +%19 = OpConstant %6 3 %15 = OpFunction %2 None %16 %14 = OpLabel OpBranch %17 %17 = OpLabel -%19 = OpAccessChain %18 %12 %7 -OpStore %19 %9 +%20 = OpAccessChain %18 %12 %19 +OpStore %20 %9 OpReturn OpFunctionEnd \ No newline at end of file