diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index d3976747ba..8fcf1d6039 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -92,6 +92,7 @@ pub enum SpirvAttribute { DescriptorSet(u32), Binding(u32), Flat, + PerPrimitiveExt, Invariant, InputAttachmentIndex(u32), SpecConstant(SpecConstant), @@ -128,6 +129,7 @@ pub struct AggregatedSpirvAttributes { pub binding: Option>, pub flat: Option>, pub invariant: Option>, + pub per_primitive_ext: Option>, pub input_attachment_index: Option>, pub spec_constant: Option>, @@ -214,6 +216,12 @@ impl AggregatedSpirvAttributes { Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"), Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"), Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"), + PerPrimitiveExt => try_insert( + &mut self.per_primitive_ext, + (), + span, + "#[spirv(per_primitive_ext)]", + ), InputAttachmentIndex(value) => try_insert( &mut self.input_attachment_index, value, @@ -315,6 +323,7 @@ impl CheckSpirvAttrVisitor<'_> { | SpirvAttribute::Binding(_) | SpirvAttribute::Flat | SpirvAttribute::Invariant + | SpirvAttribute::PerPrimitiveExt | SpirvAttribute::InputAttachmentIndex(_) | SpirvAttribute::SpecConstant(_) => match target { Target::Param => { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 865bcc7a77..97a2b441c9 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -738,6 +738,38 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global() .decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty()); } + if let Some(per_primitive_ext) = attrs.per_primitive_ext { + match execution_model { + ExecutionModel::Fragment => { + if storage_class != Ok(StorageClass::Input) { + self.tcx.dcx().span_fatal( + per_primitive_ext.span, + "`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables", + ); + } + } + ExecutionModel::MeshNV | ExecutionModel::MeshEXT => { + if storage_class != Ok(StorageClass::Output) { + self.tcx.dcx().span_fatal( + per_primitive_ext.span, + "`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables", + ); + } + } + _ => { + self.tcx.dcx().span_fatal( + per_primitive_ext.span, + "`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders", + ); + } + } + + self.emit_global().decorate( + var_id.unwrap(), + Decoration::PerPrimitiveEXT, + std::iter::empty(), + ); + } let is_subpass_input = match self.lookup_type(value_spirv_type) { SpirvType::Image { diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs index b7673f1175..082318c5a0 100644 --- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs +++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs @@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator + '_ { | Op::Kill | Op::Unreachable | Op::IgnoreIntersectionKHR - | Op::TerminateRayKHR => (0..0).step_by(1), + | Op::TerminateRayKHR + | Op::EmitMeshTasksEXT => (0..0).step_by(1), _ => panic!("Invalid block terminator: {terminator:?}"), }; operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref()) diff --git a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs index c9693b7507..148b44c61f 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs @@ -954,9 +954,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { reserved!(SPV_INTEL_device_side_avc_motion_estimation); } // SPV_EXT_mesh_shader - Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => { - reserved!(SPV_EXT_mesh_shader) - } + Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {} // SPV_NV_ray_tracing_motion_blur Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur), // SPV_NV_bindless_texture diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index d19db967e4..db8718b519 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -120,6 +120,13 @@ const BUILTINS: &[(&str, BuiltIn)] = { ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV), ("bary_coord", BaryCoordKHR), ("bary_coord_no_persp", BaryCoordNoPerspKHR), + ("primitive_point_indices_ext", PrimitivePointIndicesEXT), + ("primitive_line_indices_ext", PrimitiveLineIndicesEXT), + ( + "primitive_triangle_indices_ext", + PrimitiveTriangleIndicesEXT, + ), + ("cull_primitive_ext", CullPrimitiveEXT), ("frag_size_ext", FragSizeEXT), ("frag_invocation_count_ext", FragInvocationCountEXT), ("launch_id", BuiltIn::LaunchIdKHR), @@ -169,6 +176,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = { ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR), ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR), ("physical_storage_buffer", PhysicalStorageBuffer), + ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT), ] }; @@ -183,6 +191,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = { ("compute", GLCompute), ("task_nv", TaskNV), ("mesh_nv", MeshNV), + ("task_ext", TaskEXT), + ("mesh_ext", MeshEXT), ("ray_generation", ExecutionModel::RayGenerationKHR), ("intersection", ExecutionModel::IntersectionKHR), ("any_hit", ExecutionModel::AnyHitKHR), @@ -263,6 +273,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = { ("output_primitives_nv", OutputPrimitivesNV, Value), ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None), ("output_triangles_nv", OutputTrianglesNV, None), + ("output_lines_ext", ExecutionMode::OutputLinesEXT, None), + ( + "output_triangles_ext", + ExecutionMode::OutputTrianglesEXT, + None, + ), + ( + "output_primitives_ext", + ExecutionMode::OutputPrimitivesEXT, + Value, + ), ( "pixel_interlock_ordered_ext", PixelInterlockOrderedEXT, @@ -334,6 +355,7 @@ impl Symbols { ("block", SpirvAttribute::Block), ("flat", SpirvAttribute::Flat), ("invariant", SpirvAttribute::Invariant), + ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt), ( "sampled_image", SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage), @@ -717,7 +739,7 @@ fn parse_entry_attrs( .execution_modes .push((origin_mode, ExecutionModeExtra::new([]))); } - GLCompute | MeshNV | TaskNV => { + GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => { if let Some(local_size) = local_size { entry .execution_modes @@ -726,7 +748,7 @@ fn parse_entry_attrs( return Err(( arg.span(), String::from( - "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`", + "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`", ), )); } diff --git a/crates/spirv-std/src/arch.rs b/crates/spirv-std/src/arch.rs index c5fde66167..2d9ccdfded 100644 --- a/crates/spirv-std/src/arch.rs +++ b/crates/spirv-std/src/arch.rs @@ -17,6 +17,7 @@ mod atomics; mod barrier; mod demote_to_helper_invocation_ext; mod derivative; +mod mesh_shading; mod primitive; mod ray_tracing; mod subgroup; @@ -25,6 +26,7 @@ pub use atomics::*; pub use barrier::*; pub use demote_to_helper_invocation_ext::*; pub use derivative::*; +pub use mesh_shading::*; pub use primitive::*; pub use ray_tracing::*; pub use subgroup::*; diff --git a/crates/spirv-std/src/arch/mesh_shading.rs b/crates/spirv-std/src/arch/mesh_shading.rs new file mode 100644 index 0000000000..e3806ee86d --- /dev/null +++ b/crates/spirv-std/src/arch/mesh_shading.rs @@ -0,0 +1,109 @@ +#[cfg(target_arch = "spirv")] +use core::arch::asm; + +/// Sets the actual output size of the primitives and vertices that the mesh shader +/// workgroup will emit upon completion. +/// +/// 'Vertex Count' must be a 32-bit unsigned integer value. +/// It defines the array size of per-vertex outputs. +/// +/// 'Primitive Count' must a 32-bit unsigned integer value. +/// It defines the array size of per-primitive outputs. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction no more than once and under +/// uniform control flow. +/// There must not be any control flow path to an output write that is not preceded +/// by this instruction. +/// +/// This instruction is only valid in the *MeshEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpSetMeshOutputsEXT")] +#[inline] +pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) { + asm! { + "OpSetMeshOutputsEXT {vertex_count} {primitive_count}", + vertex_count = in(reg) vertex_count, + primitive_count = in(reg) primitive_count, + } +} + +/// Defines the grid size of subsequent mesh shader workgroups to generate +/// upon completion of the task shader workgroup. +/// +/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value. +/// They configure the number of local workgroups in each respective dimensions +/// for the launch of child mesh tasks. See Vulkan API specification for more detail. +/// +/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations. +/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction exactly once and under uniform +/// control flow. +/// This instruction also serves as an *OpControlBarrier* instruction, and also +/// performs and adheres to the description and semantics of an *OpControlBarrier* +/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and +/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and +/// *AcquireRelease*. +/// Ceases all further processing: Only instructions executed before +/// *OpEmitMeshTasksEXT* have observable side effects. +/// +/// This instruction must be the last instruction in a block. +/// +/// This instruction is only valid in the *TaskEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpEmitMeshTasksEXT")] +#[inline] +pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! { + asm! { + "OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}", + group_count_x = in(reg) group_count_x, + group_count_y = in(reg) group_count_y, + group_count_z = in(reg) group_count_z, + options(noreturn), + } +} + +/// Defines the grid size of subsequent mesh shader workgroups to generate +/// upon completion of the task shader workgroup. +/// +/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value. +/// They configure the number of local workgroups in each respective dimensions +/// for the launch of child mesh tasks. See Vulkan API specification for more detail. +/// +/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations. +/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction exactly once and under uniform +/// control flow. +/// This instruction also serves as an *OpControlBarrier* instruction, and also +/// performs and adheres to the description and semantics of an *OpControlBarrier* +/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and +/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and +/// *AcquireRelease*. +/// Ceases all further processing: Only instructions executed before +/// *OpEmitMeshTasksEXT* have observable side effects. +/// +/// This instruction must be the last instruction in a block. +/// +/// This instruction is only valid in the *TaskEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpEmitMeshTasksEXT")] +#[inline] +pub unsafe fn emit_mesh_tasks_ext_payload( + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + payload: &mut T, +) -> ! { + asm! { + "OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z} {payload}", + group_count_x = in(reg) group_count_x, + group_count_y = in(reg) group_count_y, + group_count_z = in(reg) group_count_z, + payload = in(reg) payload, + options(noreturn), + } +} diff --git a/tests/ui/arch/mesh_shader_output_lines.rs b/tests/ui/arch/mesh_shader_output_lines.rs new file mode 100644 index 0000000000..e2e83694c1 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_lines.rs @@ -0,0 +1,27 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec2, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 2, + output_primitives_ext = 1, + output_lines_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 2], + #[spirv(primitive_line_indices_ext)] indices: &mut [UVec2; 1], +) { + unsafe { + set_mesh_outputs_ext(2, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + + indices[0] = UVec2::new(0, 1); +} diff --git a/tests/ui/arch/mesh_shader_output_points.rs b/tests/ui/arch/mesh_shader_output_points.rs new file mode 100644 index 0000000000..1cc7662ca9 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_points.rs @@ -0,0 +1,26 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec2, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 1, + output_primitives_ext = 1, + output_points +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 1], + #[spirv(primitive_point_indices_ext)] indices: &mut [u32; 1], +) { + unsafe { + set_mesh_outputs_ext(1, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + + indices[0] = 0; +} diff --git a/tests/ui/arch/mesh_shader_output_triangles.rs b/tests/ui/arch/mesh_shader_output_triangles.rs new file mode 100644 index 0000000000..289f1faf62 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_triangles.rs @@ -0,0 +1,28 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0); + + indices[0] = UVec3::new(0, 1, 2); +} diff --git a/tests/ui/arch/mesh_shader_payload.rs b/tests/ui/arch/mesh_shader_payload.rs new file mode 100644 index 0000000000..3fdf11f683 --- /dev/null +++ b/tests/ui/arch/mesh_shader_payload.rs @@ -0,0 +1,35 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +pub struct Payload { + pub first: f32, + pub second: f32, + pub third: f32, +} + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], + #[spirv(task_payload_workgroup_ext)] payload: &Payload, +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = payload.first * Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = payload.second * Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = payload.third * Vec4::new(0.0, -0.5, 0.0, 1.0); + + indices[0] = UVec3::new(0, 1, 2); +} diff --git a/tests/ui/arch/mesh_shader_per_primitive.rs b/tests/ui/arch/mesh_shader_per_primitive.rs new file mode 100644 index 0000000000..29a060cbc0 --- /dev/null +++ b/tests/ui/arch/mesh_shader_per_primitive.rs @@ -0,0 +1,34 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + out_per_vertex: &mut [u32; 3], + #[spirv(per_primitive_ext)] out_per_primitive: &mut [u32; 1], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0); + out_per_vertex[0] = 0; + out_per_vertex[1] = 1; + out_per_vertex[2] = 2; + + indices[0] = UVec3::new(0, 1, 2); + out_per_primitive[0] = 42; +} diff --git a/tests/ui/arch/task_shader.rs b/tests/ui/arch/task_shader.rs new file mode 100644 index 0000000000..ea23516869 --- /dev/null +++ b/tests/ui/arch/task_shader.rs @@ -0,0 +1,13 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext; +use spirv_std::spirv; + +#[spirv(task_ext(threads(1)))] +pub fn main() { + unsafe { + emit_mesh_tasks_ext(1, 2, 3); + } +} diff --git a/tests/ui/arch/task_shader_mispile.rs b/tests/ui/arch/task_shader_mispile.rs new file mode 100644 index 0000000000..ec012789f6 --- /dev/null +++ b/tests/ui/arch/task_shader_mispile.rs @@ -0,0 +1,14 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext; +use spirv_std::spirv; + +#[spirv(task_ext(threads(1)))] +pub fn main(#[spirv(push_constant)] push: &u32) { + let count = 20 / *push; + unsafe { + emit_mesh_tasks_ext(1, 2, 3); + } +} diff --git a/tests/ui/arch/task_shader_payload.rs b/tests/ui/arch/task_shader_payload.rs new file mode 100644 index 0000000000..30f72a094a --- /dev/null +++ b/tests/ui/arch/task_shader_payload.rs @@ -0,0 +1,21 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext_payload; +use spirv_std::spirv; + +pub struct Payload { + pub first: u32, + pub second: i32, +} + +#[spirv(task_ext(threads(1)))] +pub fn main(#[spirv(task_payload_workgroup_ext)] payload: &mut Payload) { + payload.first = 1; + payload.second = 2; + + unsafe { + emit_mesh_tasks_ext_payload(3, 4, 5, payload); + } +}