Skip to content
9 changes: 9 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub enum SpirvAttribute {
DescriptorSet(u32),
Binding(u32),
Flat,
PerPrimitiveExt,
Invariant,
InputAttachmentIndex(u32),
SpecConstant(SpecConstant),
Expand Down Expand Up @@ -128,6 +129,7 @@ pub struct AggregatedSpirvAttributes {
pub binding: Option<Spanned<u32>>,
pub flat: Option<Spanned<()>>,
pub invariant: Option<Spanned<()>>,
pub per_primitive_ext: Option<Spanned<()>>,
pub input_attachment_index: Option<Spanned<u32>>,
pub spec_constant: Option<Spanned<SpecConstant>>,

Expand Down Expand Up @@ -214,6 +216,12 @@ impl AggregatedSpirvAttributes {
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
PerPrimitiveExt => try_insert(
&mut self.per_primitive_ext,
(),
span,
"#[spirv(per_primitive_ext)]",
),
InputAttachmentIndex(value) => try_insert(
&mut self.input_attachment_index,
value,
Expand Down Expand Up @@ -315,6 +323,7 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::Binding(_)
| SpirvAttribute::Flat
| SpirvAttribute::Invariant
| SpirvAttribute::PerPrimitiveExt
| SpirvAttribute::InputAttachmentIndex(_)
| SpirvAttribute::SpecConstant(_) => match target {
Target::Param => {
Expand Down
32 changes: 32 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,38 @@ impl<'tcx> CodegenCx<'tcx> {
self.emit_global()
.decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty());
}
if let Some(per_primitive_ext) = attrs.per_primitive_ext {
match execution_model {
ExecutionModel::Fragment => {
if storage_class != Ok(StorageClass::Input) {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables",
);
}
}
ExecutionModel::MeshNV | ExecutionModel::MeshEXT => {
if storage_class != Ok(StorageClass::Output) {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables",
);
}
}
_ => {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders",
);
}
}

self.emit_global().decorate(
var_id.unwrap(),
Decoration::PerPrimitiveEXT,
std::iter::empty(),
);
}

let is_subpass_input = match self.lookup_type(value_spirv_type) {
SpirvType::Image {
Expand Down
3 changes: 2 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
| Op::Kill
| Op::Unreachable
| Op::IgnoreIntersectionKHR
| Op::TerminateRayKHR => (0..0).step_by(1),
| Op::TerminateRayKHR
| Op::EmitMeshTasksEXT => (0..0).step_by(1),
_ => panic!("Invalid block terminator: {terminator:?}"),
};
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())
Expand Down
4 changes: 1 addition & 3 deletions crates/rustc_codegen_spirv/src/spirv_type_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
reserved!(SPV_INTEL_device_side_avc_motion_estimation);
}
// SPV_EXT_mesh_shader
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {
reserved!(SPV_EXT_mesh_shader)
}
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {}
// SPV_NV_ray_tracing_motion_blur
Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur),
// SPV_NV_bindless_texture
Expand Down
26 changes: 24 additions & 2 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ const BUILTINS: &[(&str, BuiltIn)] = {
("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
("bary_coord", BaryCoordKHR),
("bary_coord_no_persp", BaryCoordNoPerspKHR),
("primitive_point_indices_ext", PrimitivePointIndicesEXT),
("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
(
"primitive_triangle_indices_ext",
PrimitiveTriangleIndicesEXT,
),
("cull_primitive_ext", CullPrimitiveEXT),
("frag_size_ext", FragSizeEXT),
("frag_invocation_count_ext", FragInvocationCountEXT),
("launch_id", BuiltIn::LaunchIdKHR),
Expand Down Expand Up @@ -169,6 +176,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = {
("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
("physical_storage_buffer", PhysicalStorageBuffer),
("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
]
};

Expand All @@ -183,6 +191,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
("compute", GLCompute),
("task_nv", TaskNV),
("mesh_nv", MeshNV),
("task_ext", TaskEXT),
("mesh_ext", MeshEXT),
("ray_generation", ExecutionModel::RayGenerationKHR),
("intersection", ExecutionModel::IntersectionKHR),
("any_hit", ExecutionModel::AnyHitKHR),
Expand Down Expand Up @@ -263,6 +273,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
("output_primitives_nv", OutputPrimitivesNV, Value),
("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
("output_triangles_nv", OutputTrianglesNV, None),
("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
(
"output_triangles_ext",
ExecutionMode::OutputTrianglesEXT,
None,
),
(
"output_primitives_ext",
ExecutionMode::OutputPrimitivesEXT,
Value,
),
(
"pixel_interlock_ordered_ext",
PixelInterlockOrderedEXT,
Expand Down Expand Up @@ -334,6 +355,7 @@ impl Symbols {
("block", SpirvAttribute::Block),
("flat", SpirvAttribute::Flat),
("invariant", SpirvAttribute::Invariant),
("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
(
"sampled_image",
SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
Expand Down Expand Up @@ -717,7 +739,7 @@ fn parse_entry_attrs(
.execution_modes
.push((origin_mode, ExecutionModeExtra::new([])));
}
GLCompute | MeshNV | TaskNV => {
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
if let Some(local_size) = local_size {
entry
.execution_modes
Expand All @@ -726,7 +748,7 @@ fn parse_entry_attrs(
return Err((
arg.span(),
String::from(
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`",
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
),
));
}
Expand Down
2 changes: 2 additions & 0 deletions crates/spirv-std/src/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod atomics;
mod barrier;
mod demote_to_helper_invocation_ext;
mod derivative;
mod mesh_shading;
mod primitive;
mod ray_tracing;
mod subgroup;
Expand All @@ -25,6 +26,7 @@ pub use atomics::*;
pub use barrier::*;
pub use demote_to_helper_invocation_ext::*;
pub use derivative::*;
pub use mesh_shading::*;
pub use primitive::*;
pub use ray_tracing::*;
pub use subgroup::*;
Expand Down
109 changes: 109 additions & 0 deletions crates/spirv-std/src/arch/mesh_shading.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#[cfg(target_arch = "spirv")]
use core::arch::asm;

/// Sets the actual output size of the primitives and vertices that the mesh shader
/// workgroup will emit upon completion.
///
/// 'Vertex Count' must be a 32-bit unsigned integer value.
/// It defines the array size of per-vertex outputs.
///
/// 'Primitive Count' must a 32-bit unsigned integer value.
/// It defines the array size of per-primitive outputs.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction no more than once and under
/// uniform control flow.
/// There must not be any control flow path to an output write that is not preceded
/// by this instruction.
///
/// This instruction is only valid in the *MeshEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpSetMeshOutputsEXT")]
#[inline]
pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) {
asm! {
"OpSetMeshOutputsEXT {vertex_count} {primitive_count}",
vertex_count = in(reg) vertex_count,
primitive_count = in(reg) primitive_count,
}
}

/// Defines the grid size of subsequent mesh shader workgroups to generate
/// upon completion of the task shader workgroup.
///
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
/// They configure the number of local workgroups in each respective dimensions
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
///
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction exactly once and under uniform
/// control flow.
/// This instruction also serves as an *OpControlBarrier* instruction, and also
/// performs and adheres to the description and semantics of an *OpControlBarrier*
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
/// *AcquireRelease*.
/// Ceases all further processing: Only instructions executed before
/// *OpEmitMeshTasksEXT* have observable side effects.
///
/// This instruction must be the last instruction in a block.
///
/// This instruction is only valid in the *TaskEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpEmitMeshTasksEXT")]
#[inline]
pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! {
asm! {
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}",
group_count_x = in(reg) group_count_x,
group_count_y = in(reg) group_count_y,
group_count_z = in(reg) group_count_z,
options(noreturn),
}
}

/// Defines the grid size of subsequent mesh shader workgroups to generate
/// upon completion of the task shader workgroup.
///
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
/// They configure the number of local workgroups in each respective dimensions
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
///
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction exactly once and under uniform
/// control flow.
/// This instruction also serves as an *OpControlBarrier* instruction, and also
/// performs and adheres to the description and semantics of an *OpControlBarrier*
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
/// *AcquireRelease*.
/// Ceases all further processing: Only instructions executed before
/// *OpEmitMeshTasksEXT* have observable side effects.
///
/// This instruction must be the last instruction in a block.
///
/// This instruction is only valid in the *TaskEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpEmitMeshTasksEXT")]
#[inline]
pub unsafe fn emit_mesh_tasks_ext_payload<T>(
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
payload: &mut T,
) -> ! {
asm! {
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z} {payload}",
group_count_x = in(reg) group_count_x,
group_count_y = in(reg) group_count_y,
group_count_z = in(reg) group_count_z,
payload = in(reg) payload,
options(noreturn),
}
}
27 changes: 27 additions & 0 deletions tests/ui/arch/mesh_shader_output_lines.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec2, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 2,
output_primitives_ext = 1,
output_lines_ext
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 2],
#[spirv(primitive_line_indices_ext)] indices: &mut [UVec2; 1],
) {
unsafe {
set_mesh_outputs_ext(2, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0);

indices[0] = UVec2::new(0, 1);
}
26 changes: 26 additions & 0 deletions tests/ui/arch/mesh_shader_output_points.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec2, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 1,
output_primitives_ext = 1,
output_points
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 1],
#[spirv(primitive_point_indices_ext)] indices: &mut [u32; 1],
) {
unsafe {
set_mesh_outputs_ext(1, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);

indices[0] = 0;
}
28 changes: 28 additions & 0 deletions tests/ui/arch/mesh_shader_output_triangles.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec3, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 3,
output_primitives_ext = 1,
output_triangles_ext
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 3],
#[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1],
) {
unsafe {
set_mesh_outputs_ext(3, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0);
positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0);

indices[0] = UVec3::new(0, 1, 2);
}
Loading