From 1bf28321132b92fff2cbb1113fc0c8cb799033f6 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 22 Jun 2023 22:24:26 -0700 Subject: [PATCH] Support arrayLength of a dynamically indexed bindings array --- src/back/spv/index.rs | 47 +++++++++++++++++----- src/valid/type.rs | 2 +- tests/in/binding-buffer-arrays.wgsl | 2 + tests/out/spv/binding-buffer-arrays.spvasm | 16 ++++++-- tests/out/wgsl/binding-buffer-arrays.wgsl | 8 +++- 5 files changed, 58 insertions(+), 17 deletions(-) diff --git a/src/back/spv/index.rs b/src/back/spv/index.rs index 4df316bee7..0effb568be 100644 --- a/src/back/spv/index.rs +++ b/src/back/spv/index.rs @@ -48,25 +48,39 @@ impl<'w> BlockContext<'w> { // inside a buffer that is itself an element in a buffer bindings array. // SPIR-V requires that runtime-sized arrays are wrapped in structs. // See `helpers::global_needs_wrapper` and its uses. - let (opt_array_index, global_handle, opt_last_member_index) = match self + let (opt_array_index_id, global_handle, opt_last_member_index) = match self .ir_function .expressions[array] { - // Note that SPIR-V forbids `OpArrayLength` on a variable pointer, - // so we aren't handling `crate::Expression::Access` here. crate::Expression::AccessIndex { base, index } => { match self.ir_function.expressions[base] { // The global variable is an array of buffer bindings of structs, - // and we are accessing the last member. + // we are accessing one of them with a static index, + // and the last member of it. crate::Expression::AccessIndex { base: base_outer, index: index_outer, } => match self.ir_function.expressions[base_outer] { crate::Expression::GlobalVariable(handle) => { - (Some(index_outer), handle, Some(index)) + let index_id = self.get_index_constant(index_outer); + (Some(index_id), handle, Some(index)) } _ => return Err(Error::Validation("array length expression case-1a")), }, + // The global variable is an array of buffer bindings of structs, + // we are accessing one of them with a dynamic index, + // and the last member of it. + crate::Expression::Access { + base: base_outer, + index: index_outer, + } => match self.ir_function.expressions[base_outer] { + crate::Expression::GlobalVariable(handle) => { + let index_id = self.cached[index_outer]; + (Some(index_id), handle, Some(index)) + } + _ => return Err(Error::Validation("array length expression case-1b")), + }, + // The global variable is a buffer, and we are accessing the last member. crate::Expression::GlobalVariable(handle) => { let global = &self.ir_module.global_variables[handle]; match self.ir_module.types[global.ty].inner { @@ -79,15 +93,27 @@ impl<'w> BlockContext<'w> { _ => return Err(Error::Validation("array length expression case-1c")), } } + // The global variable is an array of buffer bindings of arrays. + crate::Expression::Access { base, index } => match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + let index_id = self.cached[index]; + let global = &self.ir_module.global_variables[handle]; + match self.ir_module.types[global.ty].inner { + crate::TypeInner::BindingArray { .. } => (Some(index_id), handle, None), + _ => return Err(Error::Validation("array length expression case-2a")), + } + } + _ => return Err(Error::Validation("array length expression case-2b")), + }, // The global variable is a run-time array. crate::Expression::GlobalVariable(handle) => { let global = &self.ir_module.global_variables[handle]; if !global_needs_wrapper(self.ir_module, global) { - return Err(Error::Validation("array length expression case-2")); + return Err(Error::Validation("array length expression case-3")); } (None, handle, None) } - _ => return Err(Error::Validation("array length expression case-3")), + _ => return Err(Error::Validation("array length expression case-4")), }; let gvar = self.writer.global_variables[global_handle.index()].clone(); @@ -103,17 +129,16 @@ impl<'w> BlockContext<'w> { (0, gvar.var_id) } }; - let structure_id = match opt_array_index { + let structure_id = match opt_array_index_id { // We are indexing inside a binding array, generate the access op. - Some(index) => { + Some(index_id) => { let element_type_id = match self.ir_module.types[global.ty].inner { crate::TypeInner::BindingArray { base, size: _ } => { let class = map_storage_class(global.space); self.get_pointer_id(base, class)? } - _ => return Err(Error::Validation("array length expression case-4")), + _ => return Err(Error::Validation("array length expression case-5")), }; - let index_id = self.get_index_constant(index); let structure_id = self.gen_id(); block.body.push(Instruction::access_chain( element_type_id, diff --git a/src/valid/type.rs b/src/valid/type.rs index f8ceb463c6..906aae2991 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -451,7 +451,6 @@ impl super::Validator { ti.uniform_layout = Ok(Alignment::MIN_UNIFORM); let mut min_offset = 0; - let mut prev_struct_data: Option<(u32, u32)> = None; for (i, member) in members.iter().enumerate() { @@ -585,6 +584,7 @@ impl super::Validator { // Currently Naga only supports binding arrays of structs for non-handle types. match gctx.types[base].inner { crate::TypeInner::Struct { .. } => {} + crate::TypeInner::Array { .. } => {} _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)), }; } diff --git a/tests/in/binding-buffer-arrays.wgsl b/tests/in/binding-buffer-arrays.wgsl index e0acc3af48..fb25623962 100644 --- a/tests/in/binding-buffer-arrays.wgsl +++ b/tests/in/binding-buffer-arrays.wgsl @@ -24,6 +24,8 @@ fn main(fragment_in: FragmentIn) -> @location(0) u32 { u1 += storage_array[non_uniform_index].x; u1 += arrayLength(&storage_array[0].far); + u1 += arrayLength(&storage_array[uniform_index].far); + u1 += arrayLength(&storage_array[non_uniform_index].far); return u1; } diff --git a/tests/out/spv/binding-buffer-arrays.spvasm b/tests/out/spv/binding-buffer-arrays.spvasm index 5bc147ecd4..daa411ba2d 100644 --- a/tests/out/spv/binding-buffer-arrays.spvasm +++ b/tests/out/spv/binding-buffer-arrays.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 69 +; Bound: 77 OpCapability Shader OpCapability ShaderNonUniform OpExtension "SPV_KHR_storage_buffer_storage_class" @@ -105,7 +105,17 @@ OpStore %19 %62 %66 = OpLoad %3 %19 %67 = OpIAdd %3 %66 %65 OpStore %19 %67 -%68 = OpLoad %3 %19 -OpStore %27 %68 +%68 = OpAccessChain %40 %12 %38 +%69 = OpArrayLength %3 %68 1 +%70 = OpLoad %3 %19 +%71 = OpIAdd %3 %70 %69 +OpStore %19 %71 +%72 = OpAccessChain %40 %12 %39 +%73 = OpArrayLength %3 %72 1 +%74 = OpLoad %3 %19 +%75 = OpIAdd %3 %74 %73 +OpStore %19 %75 +%76 = OpLoad %3 %19 +OpStore %27 %76 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/binding-buffer-arrays.wgsl b/tests/out/wgsl/binding-buffer-arrays.wgsl index 6aac5d254a..06dcc8d4a7 100644 --- a/tests/out/wgsl/binding-buffer-arrays.wgsl +++ b/tests/out/wgsl/binding-buffer-arrays.wgsl @@ -34,6 +34,10 @@ fn main(fragment_in: FragmentIn) -> @location(0) @interpolate(flat) u32 { u1_ = (_e24 + _e23); let _e31 = u1_; u1_ = (_e31 + arrayLength((&storage_array[0].far))); - let _e33 = u1_; - return _e33; + let _e37 = u1_; + u1_ = (_e37 + arrayLength((&storage_array[uniform_index].far))); + let _e43 = u1_; + u1_ = (_e43 + arrayLength((&storage_array[non_uniform_index].far))); + let _e45 = u1_; + return _e45; }