Skip to content

Commit

Permalink
[msl-out] Introduce a per-entry-point resource binding map option
Browse files Browse the repository at this point in the history
The existing `per_stage_map` field of MSL backend options specifies
resource binding maps that apply to all entry points of each stage type.
It is useful to have the ability to provide a separate binding index map
for each entry point, especially when the same shader module defines
multiple entry points of the same stage kind.

This patch introduces a new per-entry-point mapping option. When fields
are missing in the per-entry-point map, the code falls back to the
existing `per_stage_map` option as before.
  • Loading branch information
armansito committed Feb 2, 2023
1 parent a5c2cf9 commit 8f677dc
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 20 deletions.
41 changes: 30 additions & 11 deletions src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ impl ops::Index<crate::ShaderStage> for PerStageMap {
}
}

pub type PerEntryPointMap = std::collections::BTreeMap<String, PerStageResources>;

enum ResolvedBinding {
BuiltIn(crate::BuiltIn),
Attribute(u32),
Expand Down Expand Up @@ -200,6 +202,10 @@ pub struct Options {
pub lang_version: (u8, u8),
/// Map of per-stage resources to slots.
pub per_stage_map: PerStageMap,
/// Map of per-stage resources, indexed by entry point function name, to slots.
/// `per_entry_point_map` takes precedence over `per_stage_map` when computing binding slots
/// for an EntryPoint.
pub per_entry_point_map: Option<PerEntryPointMap>,
/// Samplers to be inlined into the code.
pub inline_samplers: Vec<sampler::InlineSampler>,
/// Make it possible to link different stages via SPIRV-Cross.
Expand All @@ -218,6 +224,7 @@ impl Default for Options {
Options {
lang_version: (2, 0),
per_stage_map: PerStageMap::default(),
per_entry_point_map: None,
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
Expand Down Expand Up @@ -298,10 +305,16 @@ impl Options {

fn resolve_resource_binding(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
match self.per_stage_map[stage].resources.get(res_binding) {
let target = self
.per_entry_point_map
.as_ref()
.and_then(|map| map.get(&ep.name))
.and_then(|res| res.resources.get(res_binding))
.or_else(|| self.per_stage_map[ep.stage].resources.get(res_binding));
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
Expand All @@ -312,15 +325,16 @@ impl Options {
}
}

const fn resolve_push_constants(
fn resolve_push_constants(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = match stage {
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
};
let slot = self
.per_entry_point_map
.as_ref()
.and_then(|map| map.get(&ep.name))
.and_then(|r| r.push_constant_buffer)
.or_else(|| self.per_stage_map[ep.stage].push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand All @@ -340,9 +354,14 @@ impl Options {

fn resolve_sizes_buffer(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self.per_stage_map[stage].sizes_buffer;
let slot = self
.per_entry_point_map
.as_ref()
.and_then(|map| map.get(&ep.name))
.and_then(|r| r.sizes_buffer)
.or_else(|| self.per_stage_map[ep.stage].sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand Down
22 changes: 13 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3368,7 +3368,13 @@ impl<W: Write> Writer<W> {
break;
}
};
let good = match options.per_stage_map[ep.stage].resources.get(br) {
let target = options
.per_entry_point_map
.as_ref()
.and_then(|map| map.get(&ep.name))
.and_then(|res| res.resources.get(br))
.or_else(|| options.per_stage_map[ep.stage].resources.get(br));
let good = match target {
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
Expand All @@ -3393,7 +3399,7 @@ impl<W: Write> Writer<W> {
}
}
crate::AddressSpace::PushConstant => {
if let Err(e) = options.resolve_push_constants(ep.stage) {
if let Err(e) = options.resolve_push_constants(ep) {
ep_error = Some(e);
break;
}
Expand All @@ -3404,7 +3410,7 @@ impl<W: Write> Writer<W> {
}
}
if supports_array_length {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
if let Err(err) = options.resolve_sizes_buffer(ep) {
ep_error = Some(err);
}
}
Expand Down Expand Up @@ -3673,15 +3679,13 @@ impl<W: Write> Writer<W> {
}
// the resolves have already been checked for `!fake_missing_bindings` case
let resolved = match var.space {
crate::AddressSpace::PushConstant => {
options.resolve_push_constants(ep.stage).ok()
}
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
crate::AddressSpace::WorkGroup => None,
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
return Err(Error::UnsupportedAddressSpace(var.space))
}
_ => options
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
.ok(),
};
if let Some(ref resolved) = resolved {
Expand Down Expand Up @@ -3726,7 +3730,7 @@ impl<W: Write> Writer<W> {
// passed as a final struct-typed argument.
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
let resolved = options.resolve_sizes_buffer(ep).unwrap();
let separator = if module.global_variables.is_empty() {
' '
} else {
Expand Down Expand Up @@ -3786,7 +3790,7 @@ impl<W: Write> Writer<W> {
};
} else if let Some(ref binding) = var.binding {
// write an inline sampler
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
if let Some(sampler) = resolved.as_inline_sampler(options) {
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(
Expand Down
50 changes: 50 additions & 0 deletions tests/in/resource-binding-map.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
(
god_mode: true,
msl: (
lang_version: (2, 0),
per_stage_map: (
fs: (
resources: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 1): (sampler: Some(Inline(0))),
},
)
),
per_entry_point_map: Some({
"entry_point_two": (
resources: {
(group: 0, binding: 0): (texture: Some(1)),
(group: 0, binding: 1): (sampler: Some(Resource(1))),
(group: 0, binding: 2): (buffer: Some(0)),
}
),
"entry_point_three": (
resources: {
(group: 0, binding: 2): (buffer: Some(0)),
(group: 1, binding: 0): (buffer: Some(1)),
}
)
}),
inline_samplers: [
(
coord: Normalized,
address: (ClampToEdge, ClampToEdge, ClampToEdge),
mag_filter: Linear,
min_filter: Linear,
mip_filter: None,
border_color: TransparentBlack,
compare_func: Never,
lod_clamp: Some((start: 0.5, end: 10.0)),
max_anisotropy: Some(8),
),
],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
zero_initialize_workgroup_memory: true,
),
bounds_check_policies: (
index: ReadZeroSkipWrite,
buffer: ReadZeroSkipWrite,
image: ReadZeroSkipWrite,
)
)
20 changes: 20 additions & 0 deletions tests/in/resource-binding-map.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@group(0) @binding(0) var t: texture_2d<f32>;
@group(0) @binding(1) var s: sampler;

@group(0) @binding(2) var<uniform> uniformOne: vec2<f32>;
@group(1) @binding(0) var<uniform> uniformTwo: vec2<f32>;

@fragment
fn entry_point_one(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
return textureSample(t, s, pos.xy);
}

@fragment
fn entry_point_two() -> @location(0) vec4<f32> {
return textureSample(t, s, uniformOne);
}

@fragment
fn entry_point_three() -> @location(0) vec4<f32> {
return textureSample(t, s, uniformTwo + uniformOne);
}
70 changes: 70 additions & 0 deletions tests/out/msl/resource-binding-map.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};

struct entry_point_oneInput {
};
struct entry_point_oneOutput {
metal::float4 member [[color(0)]];
};
fragment entry_point_oneOutput entry_point_one(
metal::float4 pos [[position]]
, metal::texture2d<float, metal::access::sample> t [[texture(0)]]
) {
constexpr metal::sampler s(
metal::s_address::clamp_to_edge,
metal::t_address::clamp_to_edge,
metal::r_address::clamp_to_edge,
metal::mag_filter::linear,
metal::min_filter::linear,
metal::coord::normalized
);
metal::float4 _e4 = t.sample(s, pos.xy);
return entry_point_oneOutput { _e4 };
}


struct entry_point_twoOutput {
metal::float4 member_1 [[color(0)]];
};
fragment entry_point_twoOutput entry_point_two(
metal::texture2d<float, metal::access::sample> t [[texture(1)]]
, metal::sampler s [[sampler(1)]]
, constant metal::float2& uniformOne [[buffer(0)]]
) {
metal::float2 _e3 = uniformOne;
metal::float4 _e4 = t.sample(s, _e3);
return entry_point_twoOutput { _e4 };
}


struct entry_point_threeOutput {
metal::float4 member_2 [[color(0)]];
};
fragment entry_point_threeOutput entry_point_three(
metal::texture2d<float, metal::access::sample> t [[texture(0)]]
, constant metal::float2& uniformOne [[buffer(0)]]
, constant metal::float2& uniformTwo [[buffer(1)]]
) {
constexpr metal::sampler s(
metal::s_address::clamp_to_edge,
metal::t_address::clamp_to_edge,
metal::r_address::clamp_to_edge,
metal::mag_filter::linear,
metal::min_filter::linear,
metal::coord::normalized
);
metal::float2 _e3 = uniformTwo;
metal::float2 _e5 = uniformOne;
metal::float4 _e7 = t.sample(s, _e3 + _e5);
return entry_point_threeOutput { _e7 };
}
1 change: 1 addition & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ fn convert_wgsl() {
"binding-arrays",
Targets::WGSL | Targets::HLSL | Targets::METAL | Targets::SPIRV,
),
("resource-binding-map", Targets::METAL),
("multiview", Targets::SPIRV | Targets::GLSL | Targets::WGSL),
("multiview_webgl", Targets::GLSL),
(
Expand Down

0 comments on commit 8f677dc

Please sign in to comment.