From 8f677dc46e0ebf3636353aba1eaa4e13b7141d66 Mon Sep 17 00:00:00 2001 From: Arman Uguray Date: Tue, 31 Jan 2023 20:30:55 -0800 Subject: [PATCH] [msl-out] Introduce a per-entry-point resource binding map option The existing `per_stage_map` field of MSL backend options specifies resource binding maps that apply to all entry points of each stage type. It is useful to have the ability to provide a separate binding index map for each entry point, especially when the same shader module defines multiple entry points of the same stage kind. This patch introduces a new per-entry-point mapping option. When fields are missing in the per-entry-point map, the code falls back to the existing `per_stage_map` option as before. --- src/back/msl/mod.rs | 41 +++++++++++---- src/back/msl/writer.rs | 22 ++++---- tests/in/resource-binding-map.param.ron | 50 ++++++++++++++++++ tests/in/resource-binding-map.wgsl | 20 +++++++ tests/out/msl/resource-binding-map.msl | 70 +++++++++++++++++++++++++ tests/snapshots.rs | 1 + 6 files changed, 184 insertions(+), 20 deletions(-) create mode 100644 tests/in/resource-binding-map.param.ron create mode 100644 tests/in/resource-binding-map.wgsl create mode 100644 tests/out/msl/resource-binding-map.msl diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 271557fbf2..50bc506652 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -101,6 +101,8 @@ impl ops::Index for PerStageMap { } } +pub type PerEntryPointMap = std::collections::BTreeMap; + enum ResolvedBinding { BuiltIn(crate::BuiltIn), Attribute(u32), @@ -200,6 +202,10 @@ pub struct Options { pub lang_version: (u8, u8), /// Map of per-stage resources to slots. pub per_stage_map: PerStageMap, + /// Map of per-stage resources, indexed by entry point function name, to slots. + /// `per_entry_point_map` takes precedence over `per_stage_map` when computing binding slots + /// for an EntryPoint. + pub per_entry_point_map: Option, /// Samplers to be inlined into the code. pub inline_samplers: Vec, /// Make it possible to link different stages via SPIRV-Cross. @@ -218,6 +224,7 @@ impl Default for Options { Options { lang_version: (2, 0), per_stage_map: PerStageMap::default(), + per_entry_point_map: None, inline_samplers: Vec::new(), spirv_cross_compatibility: false, fake_missing_bindings: true, @@ -298,10 +305,16 @@ impl Options { fn resolve_resource_binding( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, res_binding: &crate::ResourceBinding, ) -> Result { - match self.per_stage_map[stage].resources.get(res_binding) { + let target = self + .per_entry_point_map + .as_ref() + .and_then(|map| map.get(&ep.name)) + .and_then(|res| res.resources.get(res_binding)) + .or_else(|| self.per_stage_map[ep.stage].resources.get(res_binding)); + match target { Some(target) => Ok(ResolvedBinding::Resource(target.clone())), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", @@ -312,15 +325,16 @@ impl Options { } } - const fn resolve_push_constants( + fn resolve_push_constants( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, ) -> Result { - let slot = match stage { - crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer, - crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer, - crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer, - }; + let slot = self + .per_entry_point_map + .as_ref() + .and_then(|map| map.get(&ep.name)) + .and_then(|r| r.push_constant_buffer) + .or_else(|| self.per_stage_map[ep.stage].push_constant_buffer); match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), @@ -340,9 +354,14 @@ impl Options { fn resolve_sizes_buffer( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, ) -> Result { - let slot = self.per_stage_map[stage].sizes_buffer; + let slot = self + .per_entry_point_map + .as_ref() + .and_then(|map| map.get(&ep.name)) + .and_then(|r| r.sizes_buffer) + .or_else(|| self.per_stage_map[ep.stage].sizes_buffer); match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index d424c11b20..23b8c41b24 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3368,7 +3368,13 @@ impl Writer { break; } }; - let good = match options.per_stage_map[ep.stage].resources.get(br) { + let target = options + .per_entry_point_map + .as_ref() + .and_then(|map| map.get(&ep.name)) + .and_then(|res| res.resources.get(br)) + .or_else(|| options.per_stage_map[ep.stage].resources.get(br)); + let good = match target { Some(target) => { let binding_ty = match module.types[var.ty].inner { crate::TypeInner::BindingArray { base, .. } => { @@ -3393,7 +3399,7 @@ impl Writer { } } crate::AddressSpace::PushConstant => { - if let Err(e) = options.resolve_push_constants(ep.stage) { + if let Err(e) = options.resolve_push_constants(ep) { ep_error = Some(e); break; } @@ -3404,7 +3410,7 @@ impl Writer { } } if supports_array_length { - if let Err(err) = options.resolve_sizes_buffer(ep.stage) { + if let Err(err) = options.resolve_sizes_buffer(ep) { ep_error = Some(err); } } @@ -3673,15 +3679,13 @@ impl Writer { } // the resolves have already been checked for `!fake_missing_bindings` case let resolved = match var.space { - crate::AddressSpace::PushConstant => { - options.resolve_push_constants(ep.stage).ok() - } + crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), crate::AddressSpace::WorkGroup => None, crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => { return Err(Error::UnsupportedAddressSpace(var.space)) } _ => options - .resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap()) + .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) .ok(), }; if let Some(ref resolved) = resolved { @@ -3726,7 +3730,7 @@ impl Writer { // passed as a final struct-typed argument. if supports_array_length { // this is checked earlier - let resolved = options.resolve_sizes_buffer(ep.stage).unwrap(); + let resolved = options.resolve_sizes_buffer(ep).unwrap(); let separator = if module.global_variables.is_empty() { ' ' } else { @@ -3786,7 +3790,7 @@ impl Writer { }; } else if let Some(ref binding) = var.binding { // write an inline sampler - let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap(); + let resolved = options.resolve_resource_binding(ep, binding).unwrap(); if let Some(sampler) = resolved.as_inline_sampler(options) { let name = &self.names[&NameKey::GlobalVariable(handle)]; writeln!( diff --git a/tests/in/resource-binding-map.param.ron b/tests/in/resource-binding-map.param.ron new file mode 100644 index 0000000000..a71981d482 --- /dev/null +++ b/tests/in/resource-binding-map.param.ron @@ -0,0 +1,50 @@ +( + god_mode: true, + msl: ( + lang_version: (2, 0), + per_stage_map: ( + fs: ( + resources: { + (group: 0, binding: 0): (texture: Some(0)), + (group: 0, binding: 1): (sampler: Some(Inline(0))), + }, + ) + ), + per_entry_point_map: Some({ + "entry_point_two": ( + resources: { + (group: 0, binding: 0): (texture: Some(1)), + (group: 0, binding: 1): (sampler: Some(Resource(1))), + (group: 0, binding: 2): (buffer: Some(0)), + } + ), + "entry_point_three": ( + resources: { + (group: 0, binding: 2): (buffer: Some(0)), + (group: 1, binding: 0): (buffer: Some(1)), + } + ) + }), + inline_samplers: [ + ( + coord: Normalized, + address: (ClampToEdge, ClampToEdge, ClampToEdge), + mag_filter: Linear, + min_filter: Linear, + mip_filter: None, + border_color: TransparentBlack, + compare_func: Never, + lod_clamp: Some((start: 0.5, end: 10.0)), + max_anisotropy: Some(8), + ), + ], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + bounds_check_policies: ( + index: ReadZeroSkipWrite, + buffer: ReadZeroSkipWrite, + image: ReadZeroSkipWrite, + ) +) diff --git a/tests/in/resource-binding-map.wgsl b/tests/in/resource-binding-map.wgsl new file mode 100644 index 0000000000..76f07a7cbf --- /dev/null +++ b/tests/in/resource-binding-map.wgsl @@ -0,0 +1,20 @@ +@group(0) @binding(0) var t: texture_2d; +@group(0) @binding(1) var s: sampler; + +@group(0) @binding(2) var uniformOne: vec2; +@group(1) @binding(0) var uniformTwo: vec2; + +@fragment +fn entry_point_one(@builtin(position) pos: vec4) -> @location(0) vec4 { + return textureSample(t, s, pos.xy); +} + +@fragment +fn entry_point_two() -> @location(0) vec4 { + return textureSample(t, s, uniformOne); +} + +@fragment +fn entry_point_three() -> @location(0) vec4 { + return textureSample(t, s, uniformTwo + uniformOne); +} diff --git a/tests/out/msl/resource-binding-map.msl b/tests/out/msl/resource-binding-map.msl new file mode 100644 index 0000000000..8eeff2b375 --- /dev/null +++ b/tests/out/msl/resource-binding-map.msl @@ -0,0 +1,70 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + +struct DefaultConstructible { + template + operator T() && { + return T {}; + } +}; + +struct entry_point_oneInput { +}; +struct entry_point_oneOutput { + metal::float4 member [[color(0)]]; +}; +fragment entry_point_oneOutput entry_point_one( + metal::float4 pos [[position]] +, metal::texture2d t [[texture(0)]] +) { + constexpr metal::sampler s( + metal::s_address::clamp_to_edge, + metal::t_address::clamp_to_edge, + metal::r_address::clamp_to_edge, + metal::mag_filter::linear, + metal::min_filter::linear, + metal::coord::normalized + ); + metal::float4 _e4 = t.sample(s, pos.xy); + return entry_point_oneOutput { _e4 }; +} + + +struct entry_point_twoOutput { + metal::float4 member_1 [[color(0)]]; +}; +fragment entry_point_twoOutput entry_point_two( + metal::texture2d t [[texture(1)]] +, metal::sampler s [[sampler(1)]] +, constant metal::float2& uniformOne [[buffer(0)]] +) { + metal::float2 _e3 = uniformOne; + metal::float4 _e4 = t.sample(s, _e3); + return entry_point_twoOutput { _e4 }; +} + + +struct entry_point_threeOutput { + metal::float4 member_2 [[color(0)]]; +}; +fragment entry_point_threeOutput entry_point_three( + metal::texture2d t [[texture(0)]] +, constant metal::float2& uniformOne [[buffer(0)]] +, constant metal::float2& uniformTwo [[buffer(1)]] +) { + constexpr metal::sampler s( + metal::s_address::clamp_to_edge, + metal::t_address::clamp_to_edge, + metal::r_address::clamp_to_edge, + metal::mag_filter::linear, + metal::min_filter::linear, + metal::coord::normalized + ); + metal::float2 _e3 = uniformTwo; + metal::float2 _e5 = uniformOne; + metal::float4 _e7 = t.sample(s, _e3 + _e5); + return entry_point_threeOutput { _e7 }; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 25f307c110..5b98685f3d 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -543,6 +543,7 @@ fn convert_wgsl() { "binding-arrays", Targets::WGSL | Targets::HLSL | Targets::METAL | Targets::SPIRV, ), + ("resource-binding-map", Targets::METAL), ("multiview", Targets::SPIRV | Targets::GLSL | Targets::WGSL), ("multiview_webgl", Targets::GLSL), (