Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[msl-out] Replace per_stage_map with per_entry_point_map #2237

Merged
merged 2 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 31 additions & 39 deletions src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ holding the result.
*/

use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use std::{
fmt::{Error as FmtError, Write},
ops,
};
use std::fmt::{Error as FmtError, Write};

mod keywords;
pub mod sampler;
Expand Down Expand Up @@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTar
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageResources {
pub struct EntryPointResources {
pub resources: BindingMap,

pub push_constant_buffer: Option<Slot>,
Expand All @@ -80,26 +77,7 @@ pub struct PerStageResources {
pub sizes_buffer: Option<Slot>,
}

#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageMap {
pub vs: PerStageResources,
pub fs: PerStageResources,
pub cs: PerStageResources,
}

impl ops::Index<crate::ShaderStage> for PerStageMap {
type Output = PerStageResources;
fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
match stage {
crate::ShaderStage::Vertex => &self.vs,
crate::ShaderStage::Fragment => &self.fs,
crate::ShaderStage::Compute => &self.cs,
}
}
}
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;

enum ResolvedBinding {
BuiltIn(crate::BuiltIn),
Expand Down Expand Up @@ -198,8 +176,8 @@ enum LocationMode {
pub struct Options {
/// (Major, Minor) target version of the Metal Shading Language.
pub lang_version: (u8, u8),
/// Map of per-stage resources to slots.
pub per_stage_map: PerStageMap,
/// Map of entry-point resources, indexed by entry point function name, to slots.
pub per_entry_point_map: EntryPointResourceMap,
/// Samplers to be inlined into the code.
pub inline_samplers: Vec<sampler::InlineSampler>,
/// Make it possible to link different stages via SPIRV-Cross.
Expand All @@ -217,7 +195,7 @@ impl Default for Options {
fn default() -> Self {
Options {
lang_version: (2, 0),
per_stage_map: PerStageMap::default(),
per_entry_point_map: EntryPointResourceMap::default(),
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
Expand Down Expand Up @@ -296,12 +274,26 @@ impl Options {
}
}

fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
self.per_entry_point_map.get(&ep.name)
}

fn get_resource_binding_target(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Option<&BindTarget> {
self.get_entry_point_resources(ep)
.and_then(|res| res.resources.get(res_binding))
}

fn resolve_resource_binding(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
match self.per_stage_map[stage].resources.get(res_binding) {
let target = self.get_resource_binding_target(ep, res_binding);
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
Expand All @@ -312,15 +304,13 @@ impl Options {
}
}

const fn resolve_push_constants(
fn resolve_push_constants(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = match stage {
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
};
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand All @@ -340,9 +330,11 @@ impl Options {

fn resolve_sizes_buffer(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self.per_stage_map[stage].sizes_buffer;
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand Down
17 changes: 8 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3368,7 +3368,8 @@ impl<W: Write> Writer<W> {
break;
}
};
let good = match options.per_stage_map[ep.stage].resources.get(br) {
let target = options.get_resource_binding_target(ep, br);
let good = match target {
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
Expand All @@ -3393,7 +3394,7 @@ impl<W: Write> Writer<W> {
}
}
crate::AddressSpace::PushConstant => {
if let Err(e) = options.resolve_push_constants(ep.stage) {
if let Err(e) = options.resolve_push_constants(ep) {
ep_error = Some(e);
break;
}
Expand All @@ -3404,7 +3405,7 @@ impl<W: Write> Writer<W> {
}
}
if supports_array_length {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
if let Err(err) = options.resolve_sizes_buffer(ep) {
ep_error = Some(err);
}
}
Expand Down Expand Up @@ -3673,15 +3674,13 @@ impl<W: Write> Writer<W> {
}
// the resolves have already been checked for `!fake_missing_bindings` case
let resolved = match var.space {
crate::AddressSpace::PushConstant => {
options.resolve_push_constants(ep.stage).ok()
}
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
crate::AddressSpace::WorkGroup => None,
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
return Err(Error::UnsupportedAddressSpace(var.space))
}
_ => options
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
.ok(),
};
if let Some(ref resolved) = resolved {
Expand Down Expand Up @@ -3726,7 +3725,7 @@ impl<W: Write> Writer<W> {
// passed as a final struct-typed argument.
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
let resolved = options.resolve_sizes_buffer(ep).unwrap();
let separator = if module.global_variables.is_empty() {
' '
} else {
Expand Down Expand Up @@ -3786,7 +3785,7 @@ impl<W: Write> Writer<W> {
};
} else if let Some(ref binding) = var.binding {
// write an inline sampler
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
if let Some(sampler) = resolved.as_inline_sampler(options) {
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(
Expand Down
10 changes: 5 additions & 5 deletions tests/in/access.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"foo_vert": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
Expand All @@ -16,20 +16,20 @@
},
sizes_buffer: Some(24),
),
fs: (
"foo_frag": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
},
sizes_buffer: Some(24),
),
cs: (
"atomics": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
},
sizes_buffer: Some(24),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/binding-arrays.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
},
sizes_buffer: None,
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: true,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/bitcast.params.ron
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/bits.param.ron
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/boids.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
},
sizes_buffer: Some(3),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/extra.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
),
msl: (
lang_version: (2, 2),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
push_constant_buffer: Some(1),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
2 changes: 1 addition & 1 deletion tests/in/interface.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
),
msl: (
lang_version: (2, 1),
per_stage_map: (),
per_entry_point_map: {},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/padding.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"vertex": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
(group: 0, binding: 2): (buffer: Some(2), mutable: false),
},
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
Loading