Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mesh shaders #44

Merged
merged 8 commits into from
Dec 25, 2024
9 changes: 9 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub enum SpirvAttribute {
DescriptorSet(u32),
Binding(u32),
Flat,
PerPrimitiveExt,
Invariant,
InputAttachmentIndex(u32),
SpecConstant(SpecConstant),
Expand Down Expand Up @@ -128,6 +129,7 @@ pub struct AggregatedSpirvAttributes {
pub binding: Option<Spanned<u32>>,
pub flat: Option<Spanned<()>>,
pub invariant: Option<Spanned<()>>,
pub per_primitive_ext: Option<Spanned<()>>,
pub input_attachment_index: Option<Spanned<u32>>,
pub spec_constant: Option<Spanned<SpecConstant>>,

Expand Down Expand Up @@ -214,6 +216,12 @@ impl AggregatedSpirvAttributes {
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
PerPrimitiveExt => try_insert(
&mut self.per_primitive_ext,
(),
span,
"#[spirv(per_primitive_ext)]",
),
InputAttachmentIndex(value) => try_insert(
&mut self.input_attachment_index,
value,
Expand Down Expand Up @@ -315,6 +323,7 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::Binding(_)
| SpirvAttribute::Flat
| SpirvAttribute::Invariant
| SpirvAttribute::PerPrimitiveExt
| SpirvAttribute::InputAttachmentIndex(_)
| SpirvAttribute::SpecConstant(_) => match target {
Target::Param => {
Expand Down
32 changes: 32 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,38 @@ impl<'tcx> CodegenCx<'tcx> {
self.emit_global()
.decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty());
}
if let Some(per_primitive_ext) = attrs.per_primitive_ext {
match execution_model {
ExecutionModel::Fragment => {
if storage_class != Ok(StorageClass::Input) {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables",
);
}
}
ExecutionModel::MeshNV | ExecutionModel::MeshEXT => {
if storage_class != Ok(StorageClass::Output) {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables",
);
}
}
_ => {
self.tcx.dcx().span_fatal(
per_primitive_ext.span,
"`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders",
);
}
}

self.emit_global().decorate(
var_id.unwrap(),
Decoration::PerPrimitiveEXT,
std::iter::empty(),
);
}

let is_subpass_input = match self.lookup_type(value_spirv_type) {
SpirvType::Image {
Expand Down
3 changes: 2 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
| Op::Kill
| Op::Unreachable
| Op::IgnoreIntersectionKHR
| Op::TerminateRayKHR => (0..0).step_by(1),
| Op::TerminateRayKHR
| Op::EmitMeshTasksEXT => (0..0).step_by(1),
_ => panic!("Invalid block terminator: {terminator:?}"),
};
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())
Expand Down
4 changes: 1 addition & 3 deletions crates/rustc_codegen_spirv/src/spirv_type_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
reserved!(SPV_INTEL_device_side_avc_motion_estimation);
}
// SPV_EXT_mesh_shader
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {
reserved!(SPV_EXT_mesh_shader)
}
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {}
// SPV_NV_ray_tracing_motion_blur
Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur),
// SPV_NV_bindless_texture
Expand Down
26 changes: 24 additions & 2 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ const BUILTINS: &[(&str, BuiltIn)] = {
("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
("bary_coord", BaryCoordKHR),
("bary_coord_no_persp", BaryCoordNoPerspKHR),
("primitive_point_indices_ext", PrimitivePointIndicesEXT),
("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
(
"primitive_triangle_indices_ext",
PrimitiveTriangleIndicesEXT,
),
("cull_primitive_ext", CullPrimitiveEXT),
("frag_size_ext", FragSizeEXT),
("frag_invocation_count_ext", FragInvocationCountEXT),
("launch_id", BuiltIn::LaunchIdKHR),
Expand Down Expand Up @@ -169,6 +176,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = {
("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
("physical_storage_buffer", PhysicalStorageBuffer),
("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
]
};

Expand All @@ -183,6 +191,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
("compute", GLCompute),
("task_nv", TaskNV),
("mesh_nv", MeshNV),
("task_ext", TaskEXT),
("mesh_ext", MeshEXT),
("ray_generation", ExecutionModel::RayGenerationKHR),
("intersection", ExecutionModel::IntersectionKHR),
("any_hit", ExecutionModel::AnyHitKHR),
Expand Down Expand Up @@ -263,6 +273,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
("output_primitives_nv", OutputPrimitivesNV, Value),
("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
("output_triangles_nv", OutputTrianglesNV, None),
("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
(
"output_triangles_ext",
ExecutionMode::OutputTrianglesEXT,
None,
),
(
"output_primitives_ext",
ExecutionMode::OutputPrimitivesEXT,
Value,
),
(
"pixel_interlock_ordered_ext",
PixelInterlockOrderedEXT,
Expand Down Expand Up @@ -334,6 +355,7 @@ impl Symbols {
("block", SpirvAttribute::Block),
("flat", SpirvAttribute::Flat),
("invariant", SpirvAttribute::Invariant),
("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
(
"sampled_image",
SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
Expand Down Expand Up @@ -717,7 +739,7 @@ fn parse_entry_attrs(
.execution_modes
.push((origin_mode, ExecutionModeExtra::new([])));
}
GLCompute | MeshNV | TaskNV => {
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
if let Some(local_size) = local_size {
entry
.execution_modes
Expand All @@ -726,7 +748,7 @@ fn parse_entry_attrs(
return Err((
arg.span(),
String::from(
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`",
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
),
));
}
Expand Down
2 changes: 2 additions & 0 deletions crates/spirv-std/src/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod atomics;
mod barrier;
mod demote_to_helper_invocation_ext;
mod derivative;
mod mesh_shading;
mod primitive;
mod ray_tracing;
mod subgroup;
Expand All @@ -25,6 +26,7 @@ pub use atomics::*;
pub use barrier::*;
pub use demote_to_helper_invocation_ext::*;
pub use derivative::*;
pub use mesh_shading::*;
pub use primitive::*;
pub use ray_tracing::*;
pub use subgroup::*;
Expand Down
109 changes: 109 additions & 0 deletions crates/spirv-std/src/arch/mesh_shading.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#[cfg(target_arch = "spirv")]
use core::arch::asm;

/// Sets the actual output size of the primitives and vertices that the mesh shader
/// workgroup will emit upon completion.
///
/// 'Vertex Count' must be a 32-bit unsigned integer value.
/// It defines the array size of per-vertex outputs.
///
/// 'Primitive Count' must a 32-bit unsigned integer value.
/// It defines the array size of per-primitive outputs.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction no more than once and under
/// uniform control flow.
/// There must not be any control flow path to an output write that is not preceded
/// by this instruction.
///
/// This instruction is only valid in the *MeshEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpSetMeshOutputsEXT")]
#[inline]
pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) {
asm! {
"OpSetMeshOutputsEXT {vertex_count} {primitive_count}",
vertex_count = in(reg) vertex_count,
primitive_count = in(reg) primitive_count,
}
}

/// Defines the grid size of subsequent mesh shader workgroups to generate
/// upon completion of the task shader workgroup.
///
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
/// They configure the number of local workgroups in each respective dimensions
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
///
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction exactly once and under uniform
/// control flow.
/// This instruction also serves as an *OpControlBarrier* instruction, and also
/// performs and adheres to the description and semantics of an *OpControlBarrier*
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
/// *AcquireRelease*.
/// Ceases all further processing: Only instructions executed before
/// *OpEmitMeshTasksEXT* have observable side effects.
///
/// This instruction must be the last instruction in a block.
///
/// This instruction is only valid in the *TaskEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpEmitMeshTasksEXT")]
#[inline]
pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! {
asm! {
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}",
group_count_x = in(reg) group_count_x,
group_count_y = in(reg) group_count_y,
group_count_z = in(reg) group_count_z,
options(noreturn),
}
}

/// Defines the grid size of subsequent mesh shader workgroups to generate
/// upon completion of the task shader workgroup.
///
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
/// They configure the number of local workgroups in each respective dimensions
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
///
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
///
/// The arguments are taken from the first invocation in each workgroup.
/// Any invocation must execute this instruction exactly once and under uniform
/// control flow.
/// This instruction also serves as an *OpControlBarrier* instruction, and also
/// performs and adheres to the description and semantics of an *OpControlBarrier*
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
/// *AcquireRelease*.
/// Ceases all further processing: Only instructions executed before
/// *OpEmitMeshTasksEXT* have observable side effects.
///
/// This instruction must be the last instruction in a block.
///
/// This instruction is only valid in the *TaskEXT* Execution Model.
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpEmitMeshTasksEXT")]
#[inline]
pub unsafe fn emit_mesh_tasks_ext_payload<T>(
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
payload: &mut T,
) -> ! {
asm! {
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z} {payload}",
group_count_x = in(reg) group_count_x,
group_count_y = in(reg) group_count_y,
group_count_z = in(reg) group_count_z,
payload = in(reg) payload,
options(noreturn),
}
}
27 changes: 27 additions & 0 deletions tests/ui/arch/mesh_shader_output_lines.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec2, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 2,
output_primitives_ext = 1,
output_lines_ext
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 2],
#[spirv(primitive_line_indices_ext)] indices: &mut [UVec2; 1],
) {
unsafe {
set_mesh_outputs_ext(2, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0);

indices[0] = UVec2::new(0, 1);
}
26 changes: 26 additions & 0 deletions tests/ui/arch/mesh_shader_output_points.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec2, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 1,
output_primitives_ext = 1,
output_points
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 1],
#[spirv(primitive_point_indices_ext)] indices: &mut [u32; 1],
) {
unsafe {
set_mesh_outputs_ext(1, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);

indices[0] = 0;
}
28 changes: 28 additions & 0 deletions tests/ui/arch/mesh_shader_output_triangles.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// build-pass
// only-vulkan1.2
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec3, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 3,
output_primitives_ext = 1,
output_triangles_ext
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 3],
#[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1],
) {
unsafe {
set_mesh_outputs_ext(3, 1);
}

positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0);
positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0);

indices[0] = UVec3::new(0, 1, 2);
}
Loading