Skip to content

Commit

Permalink
Refactor some shader things and add more validation (#2335)
Browse files Browse the repository at this point in the history
* Refactor some shader things and add more validation

* Remove pub
  • Loading branch information
Rua authored Sep 21, 2023
1 parent e9790c1 commit a8ca0a7
Show file tree
Hide file tree
Showing 7 changed files with 599 additions and 555 deletions.
4 changes: 2 additions & 2 deletions vulkano-shaders/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ mod tests {
let spirv = Spirv::new(&instructions).unwrap();

let mut descriptors = Vec::new();
for info in reflect::entry_points(&spirv) {
for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements);
}

Expand Down Expand Up @@ -622,7 +622,7 @@ mod tests {
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();

if let Some(info) = reflect::entry_points(&spirv).next() {
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc);
Expand Down
6 changes: 3 additions & 3 deletions vulkano/src/pipeline/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
pipeline::{cache::PipelineCache, layout::PipelineLayout, Pipeline, PipelineBindPoint},
shader::{DescriptorBindingRequirements, ShaderExecution, ShaderStage},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements, ShaderStage},
Validated, ValidationError, VulkanError, VulkanObject,
};
use ahash::HashMap;
Expand Down Expand Up @@ -155,7 +155,7 @@ impl ComputePipeline {
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
stage: ShaderStage::from(entry_point_info.execution_model).into(),
module: entry_point.module().handle(),
p_name: name_vk.as_ptr(),
p_specialization_info: if specialization_info_vk.data_size == 0 {
Expand Down Expand Up @@ -410,7 +410,7 @@ impl ComputePipelineCreateInfo {

let entry_point_info = entry_point.info();

if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
if !matches!(entry_point_info.execution_model, ExecutionModel::GLCompute) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),
Expand Down
125 changes: 104 additions & 21 deletions vulkano/src/pipeline/graphics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ use crate::{
PartialStateMode,
},
shader::{
DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages,
ShaderExecution, ShaderStage, ShaderStages,
spirv::{ExecutionMode, ExecutionModel, Instruction},
DescriptorBindingRequirements, ShaderStage, ShaderStages,
},
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, VulkanError, VulkanObject,
};
Expand Down Expand Up @@ -220,7 +220,7 @@ impl GraphicsPipeline {
} = stage;

let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);

let mut specialization_data_vk: Vec<u8> = Vec::new();
let specialization_map_entries_vk: Vec<_> = entry_point
Expand Down Expand Up @@ -1223,15 +1223,28 @@ impl GraphicsPipeline {
} = stage;

let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);
shaders.insert(stage, ());

if let ShaderExecution::Fragment(FragmentShaderExecution {
fragment_tests_stages: s,
..
}) = entry_point_info.execution
{
fragment_tests_stages = Some(s)
let spirv = entry_point.module().spirv();
let entry_point_function = spirv.function(entry_point.id());

if matches!(entry_point_info.execution_model, ExecutionModel::Fragment) {
fragment_tests_stages = Some(FragmentTestsStages::Late);

for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = Some(FragmentTestsStages::Early);
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = Some(FragmentTestsStages::EarlyAndLate);
}
_ => (),
}
}
}
}

for (&loc, reqs) in &entry_point_info.descriptor_binding_requirements {
Expand Down Expand Up @@ -1989,7 +2002,7 @@ impl GraphicsPipelineCreateInfo {

for (stage_index, stage) in stages.iter().enumerate() {
let entry_point_info = stage.entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution);
let stage_enum = ShaderStage::from(entry_point_info.execution_model);
let stage_flag = ShaderStages::from(stage_enum);

if stages_present.intersects(stage_flag) {
Expand Down Expand Up @@ -2081,9 +2094,12 @@ impl GraphicsPipelineCreateInfo {
}

let need_vertex_input_state = need_pre_rasterization_shader_state
&& stages
.iter()
.any(|stage| matches!(stage.entry_point.info().execution, ShaderExecution::Vertex));
&& stages.iter().any(|stage| {
matches!(
stage.entry_point.info().execution_model,
ExecutionModel::Vertex
)
});
let need_fragment_shader_state = need_pre_rasterization_shader_state
&& rasterization_state
.as_ref()
Expand Down Expand Up @@ -2535,8 +2551,8 @@ impl GraphicsPipelineCreateInfo {
problem: format!(
"the output interface of the `ShaderStage::{:?}` stage does not \
match the input interface of the `ShaderStage::{:?}` stage: {}",
ShaderStage::from(&output.entry_point.info().execution),
ShaderStage::from(&input.entry_point.info().execution),
ShaderStage::from(output.entry_point.info().execution_model),
ShaderStage::from(input.entry_point.info().execution_model),
err
)
.into(),
Expand Down Expand Up @@ -2816,11 +2832,30 @@ impl GraphicsPipelineCreateInfo {
geometry_stage,
input_assembly_state,
) {
let entry_point_info = geometry_stage.entry_point.info();
let input = match entry_point_info.execution {
ShaderExecution::Geometry(execution) => execution.input,
_ => unreachable!(),
};
let spirv = geometry_stage.entry_point.module().spirv();
let entry_point_function = spirv.function(geometry_stage.entry_point.id());

let input = entry_point_function
.iter_execution_mode()
.find_map(|instruction| {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::InputPoints => Some(GeometryShaderInput::Points),
ExecutionMode::InputLines => Some(GeometryShaderInput::Lines),
ExecutionMode::InputLinesAdjacency => {
Some(GeometryShaderInput::LinesWithAdjacency)
}
ExecutionMode::Triangles => Some(GeometryShaderInput::Triangles),
ExecutionMode::InputTrianglesAdjacency => {
Some(GeometryShaderInput::TrianglesWithAdjacency)
}
_ => None,
}
} else {
None
}
})
.unwrap();

if let PartialStateMode::Fixed(topology) = input_assembly_state.topology {
if !input.is_compatible_with(topology) {
Expand Down Expand Up @@ -3104,3 +3139,51 @@ impl GraphicsPipelineCreateInfo {
Ok(())
}
}

/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}

impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}

/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}
Loading

0 comments on commit a8ca0a7

Please sign in to comment.