Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor some shader things and add more validation #2335

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vulkano-shaders/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ mod tests {
let spirv = Spirv::new(&instructions).unwrap();

let mut descriptors = Vec::new();
for info in reflect::entry_points(&spirv) {
for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements);
}

Expand Down Expand Up @@ -622,7 +622,7 @@ mod tests {
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();

if let Some(info) = reflect::entry_points(&spirv).next() {
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc);
Expand Down
6 changes: 3 additions & 3 deletions vulkano/src/pipeline/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
pipeline::{cache::PipelineCache, layout::PipelineLayout, Pipeline, PipelineBindPoint},
shader::{DescriptorBindingRequirements, ShaderExecution, ShaderStage},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements, ShaderStage},
Validated, ValidationError, VulkanError, VulkanObject,
};
use ahash::HashMap;
Expand Down Expand Up @@ -155,7 +155,7 @@ impl ComputePipeline {
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
stage: ShaderStage::from(entry_point_info.execution_model).into(),
module: entry_point.module().handle(),
p_name: name_vk.as_ptr(),
p_specialization_info: if specialization_info_vk.data_size == 0 {
Expand Down Expand Up @@ -410,7 +410,7 @@ impl ComputePipelineCreateInfo {

let entry_point_info = entry_point.info();

if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
if !matches!(entry_point_info.execution_model, ExecutionModel::GLCompute) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),
Expand Down
125 changes: 104 additions & 21 deletions vulkano/src/pipeline/graphics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ use crate::{
PartialStateMode,
},
shader::{
DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages,
ShaderExecution, ShaderStage, ShaderStages,
spirv::{ExecutionMode, ExecutionModel, Instruction},
DescriptorBindingRequirements, ShaderStage, ShaderStages,
},
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, VulkanError, VulkanObject,
};
Expand Down Expand Up @@ -220,7 +220,7 @@ impl GraphicsPipeline {
} = stage;

let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);

let mut specialization_data_vk: Vec<u8> = Vec::new();
let specialization_map_entries_vk: Vec<_> = entry_point
Expand Down Expand Up @@ -1223,15 +1223,28 @@ impl GraphicsPipeline {
} = stage;

let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);
shaders.insert(stage, ());

if let ShaderExecution::Fragment(FragmentShaderExecution {
fragment_tests_stages: s,
..
}) = entry_point_info.execution
{
fragment_tests_stages = Some(s)
let spirv = entry_point.module().spirv();
let entry_point_function = spirv.function(entry_point.id());

if matches!(entry_point_info.execution_model, ExecutionModel::Fragment) {
fragment_tests_stages = Some(FragmentTestsStages::Late);

for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = Some(FragmentTestsStages::Early);
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = Some(FragmentTestsStages::EarlyAndLate);
}
_ => (),
}
}
}
}

for (&loc, reqs) in &entry_point_info.descriptor_binding_requirements {
Expand Down Expand Up @@ -1989,7 +2002,7 @@ impl GraphicsPipelineCreateInfo {

for (stage_index, stage) in stages.iter().enumerate() {
let entry_point_info = stage.entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution);
let stage_enum = ShaderStage::from(entry_point_info.execution_model);
let stage_flag = ShaderStages::from(stage_enum);

if stages_present.intersects(stage_flag) {
Expand Down Expand Up @@ -2081,9 +2094,12 @@ impl GraphicsPipelineCreateInfo {
}

let need_vertex_input_state = need_pre_rasterization_shader_state
&& stages
.iter()
.any(|stage| matches!(stage.entry_point.info().execution, ShaderExecution::Vertex));
&& stages.iter().any(|stage| {
matches!(
stage.entry_point.info().execution_model,
ExecutionModel::Vertex
)
});
let need_fragment_shader_state = need_pre_rasterization_shader_state
&& rasterization_state
.as_ref()
Expand Down Expand Up @@ -2535,8 +2551,8 @@ impl GraphicsPipelineCreateInfo {
problem: format!(
"the output interface of the `ShaderStage::{:?}` stage does not \
match the input interface of the `ShaderStage::{:?}` stage: {}",
ShaderStage::from(&output.entry_point.info().execution),
ShaderStage::from(&input.entry_point.info().execution),
ShaderStage::from(output.entry_point.info().execution_model),
ShaderStage::from(input.entry_point.info().execution_model),
err
)
.into(),
Expand Down Expand Up @@ -2816,11 +2832,30 @@ impl GraphicsPipelineCreateInfo {
geometry_stage,
input_assembly_state,
) {
let entry_point_info = geometry_stage.entry_point.info();
let input = match entry_point_info.execution {
ShaderExecution::Geometry(execution) => execution.input,
_ => unreachable!(),
};
let spirv = geometry_stage.entry_point.module().spirv();
let entry_point_function = spirv.function(geometry_stage.entry_point.id());

let input = entry_point_function
.iter_execution_mode()
.find_map(|instruction| {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::InputPoints => Some(GeometryShaderInput::Points),
ExecutionMode::InputLines => Some(GeometryShaderInput::Lines),
ExecutionMode::InputLinesAdjacency => {
Some(GeometryShaderInput::LinesWithAdjacency)
}
ExecutionMode::Triangles => Some(GeometryShaderInput::Triangles),
ExecutionMode::InputTrianglesAdjacency => {
Some(GeometryShaderInput::TrianglesWithAdjacency)
}
_ => None,
}
} else {
None
}
})
.unwrap();

if let PartialStateMode::Fixed(topology) = input_assembly_state.topology {
if !input.is_compatible_with(topology) {
Expand Down Expand Up @@ -3104,3 +3139,51 @@ impl GraphicsPipelineCreateInfo {
Ok(())
}
}

/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}

impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}

/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}
Loading