diff --git a/CHANGELOG.md b/CHANGELOG.md index fb156281c0f..13e9ef04e75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ Bottom level categories: #### Deferred command buffer actions: `map_buffer_on_submit` and `on_submitted_work_done` -You may schedule buffer mapping and a submission-complete callback to run automatically after you submit, directly from encoders, command buffers, and passes. +You may schedule buffer mapping and a submission-complete callback to run automatically after you submit, directly from encoders, command buffers, and passes. ```rust // Record some GPU work so the submission isn't empty and touches `buffer`. @@ -150,7 +150,7 @@ By @cwfitzgerald in [#8163](https://github.com/gfx-rs/wgpu/pull/8163). #### Multi-draw indirect is now unconditionally supported when indirect draws are supported -We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms. +We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms. `RenderPass::multi_draw_indirect` is now available if the device supports downlevel flag `DownlevelFlags::INDIRECT_EXECUTION`. If you are using spirv-passthrough with multi-draw indirect and `gl_DrawID`, you can know if `MULTI_DRAW_INDIRECT` is being emulated @@ -166,6 +166,8 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162). - Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386). +- Added support for cooperative load/store operations in shaders. 
Currently only WGSL on the input and SPIR-V, Metal, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). + ### Changes #### General diff --git a/Cargo.lock b/Cargo.lock index 7ccdeeccbe0..3f6f281d55d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3584,9 +3584,8 @@ dependencies = [ [[package]] name = "rspirv" -version = "0.12.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +version = "0.12.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" dependencies = [ "rustc-hash 1.1.0", "spirv", @@ -3961,9 +3960,8 @@ dependencies = [ [[package]] name = "spirv" -version = "0.3.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" +version = "0.3.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" dependencies = [ "bitflags 2.9.4", "serde", diff --git a/Cargo.toml b/Cargo.toml index bd58ec379ed..24360b6acea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -254,6 +254,8 @@ ndk-sys = "0.6" # These overrides allow our examples to explicitly depend on release crates [patch.crates-io] wgpu = { path = "./wgpu" } +rspirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" } +spirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" } [profile.release] lto = "thin" diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 826dad1c219..7be93c90032 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -403,6 +403,16 @@ impl StatementGraph { }, } } + S::CooperativeStore { target, data } => { + 
self.dependencies.push((id, target, "target")); + self.dependencies.push((id, data.pointer, "pointer")); + self.dependencies.push((id, data.stride, "stride")); + if data.row_major { + "CoopStoreT" + } else { + "CoopStore" + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -742,6 +752,18 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } + E::CooperativeLoad { ref data, .. } => { + edges.insert("pointer", data.pointer); + edges.insert("stride", data.stride); + let suffix = if data.row_major { "T " } else { "" }; + (format!("coopLoad{suffix}").into(), 4) + } + E::CooperativeMultiplyAdd { a, b, c } => { + edges.insert("a", a); + edges.insert("b", b); + edges.insert("c", c); + ("cooperativeMultiplyAdd".into(), 4) + } }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 4c5a9d8cbcb..3c3b70289a6 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -1107,7 +1107,8 @@ impl<'a, W: Write> Writer<'a, W> { TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?, // Write all variants instead of `_` so that if new variants are added a // no exhaustiveness error is thrown - TypeInner::Pointer { .. } + TypeInner::CooperativeMatrix { .. } + | TypeInner::Pointer { .. } | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. } @@ -2804,6 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeStore { .. } => unimplemented!(), } Ok(()) @@ -4340,7 +4342,9 @@ impl<'a, W: Write> Writer<'a, W> { } // not supported yet Expression::RayQueryGetIntersection { .. } - | Expression::RayQueryVertexPositions { .. } => unreachable!(), + | Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeLoad { .. } + | Expression::CooperativeMultiplyAdd { .. 
} => unreachable!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index ab95b9327f9..33ad58e17b3 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2747,6 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeStore { .. } => unimplemented!(), } Ok(()) @@ -4275,7 +4276,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } // Not supported yet - Expression::RayQueryVertexPositions { .. } => unreachable!(), + Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeLoad { .. } + | Expression::CooperativeMultiplyAdd { .. } => { + unreachable!() + } // Nothing to do here, since call expression already cached Expression::CallResult(_) | Expression::AtomicResult { .. } diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 0d13d63dd9b..092f9d1cd10 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -312,12 +312,10 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } impl crate::TypeInner { - /// Returns true if this is a handle to a type rather than the type directly. + /// Returns true if a variable of this type is a handle. pub const fn is_handle(&self) -> bool { match *self { - crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } - | crate::TypeInner::AccelerationStructure { .. } => true, + Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. 
} => true, _ => false, } } diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 64b1280a1b0..dfeba8f8963 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -228,8 +228,10 @@ pub enum Error { UnsupportedArrayOf(String), #[error("array of type '{0:?}' is not supported")] UnsupportedArrayOfType(Handle), - #[error("ray tracing is not supported prior to MSL 2.3")] + #[error("ray tracing is not supported prior to MSL 2.4")] UnsupportedRayTracing, + #[error("cooperative matrix is not supported prior to MSL 2.3")] + UnsupportedCooperativeMatrix, #[error("overrides should not be present at this stage")] Override, #[error("bitcasting to {0:?} is not supported")] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index cec92265416..952e52eda6d 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -78,6 +78,8 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp /// allowing them to be conveniently passed to user-defined or wrapper /// functions. The struct is declared in [`Writer::write_type_defs`]. pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper"; +pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad"; +pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd"; /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. 
/// @@ -235,6 +237,21 @@ impl Display for TypeContext<'_> { rows, scalar, } => put_numeric_type(out, scalar, &[rows, columns]), + // Requires Metal-2.3 + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role: _, + } => { + write!( + out, + "{NAMESPACE}::simdgroup_{}{}x{}", + scalar.to_msl_name(), + columns as u32, + rows as u32, + ) + } crate::TypeInner::Pointer { base, space } => { let sub = Self { handle: base, @@ -468,6 +485,19 @@ enum WrappedFunction { ImageQuerySize { class: crate::ImageClass, }, + CooperativeLoad { + space: crate::AddressSpace, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + scalar: crate::Scalar, + }, + CooperativeMultiplyAdd { + space: crate::AddressSpace, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + intermediate: crate::CooperativeSize, + scalar: crate::Scalar, + }, } pub struct Writer { @@ -637,6 +667,7 @@ impl crate::Type { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::Pointer { .. } | Ti::ValuePointer { .. } => self.name.is_some(), @@ -2818,6 +2849,29 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::CooperativeLoad { ref data, .. 
} => { + if context.lang_version < (2, 3) { + return Err(Error::UnsupportedCooperativeMatrix); + } + write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?; + write!(self.out, "&")?; + self.put_access_chain(data.pointer, context.policies.index, context)?; + write!(self.out, ", ")?; + self.put_expression(data.stride, context, true)?; + write!(self.out, ", {})", data.row_major)?; + } + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + if context.lang_version < (2, 3) { + return Err(Error::UnsupportedCooperativeMatrix); + } + write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; + self.put_expression(a, context, true)?; + write!(self.out, ", ")?; + self.put_expression(b, context, true)?; + write!(self.out, ", ")?; + self.put_expression(c, context, true)?; + write!(self.out, ")")?; + } } Ok(()) } @@ -4199,6 +4253,24 @@ impl Writer { } writeln!(self.out, ");")?; } + crate::Statement::CooperativeStore { target, ref data } => { + write!(self.out, "{level}simdgroup_store(")?; + self.put_expression(target, &context.expression, true)?; + write!(self.out, ", &")?; + self.put_access_chain( + data.pointer, + context.expression.policies.index, + &context.expression, + )?; + write!(self.out, ", ")?; + self.put_expression(data.stride, &context.expression, true)?; + if data.row_major { + let matrix_origin = "0"; + let transpose = true; + write!(self.out, ", {matrix_origin}, {transpose}")?; + } + writeln!(self.out, ");")?; + } } } @@ -6255,6 +6327,106 @@ template Ok(()) } + fn write_wrapped_cooperative_load( + &mut self, + module: &crate::Module, + func_ctx: &back::FunctionCtx, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + pointer: Handle, + ) -> BackendResult { + let ptr_ty = func_ctx.resolve_type(pointer, &module.types); + let space = ptr_ty.pointer_space().unwrap(); + let scalar = ptr_ty + .pointer_base_type() + .unwrap() + .inner_with(&module.types) + .scalar() + .unwrap(); + let wrapped = WrappedFunction::CooperativeLoad { + space, + columns, + 
rows, + scalar, + }; + if !self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + let space_name = space.to_msl_name().unwrap_or_default(); + let scalar_name = scalar.to_msl_name(); + writeln!( + self.out, + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{", + columns as u32, rows as u32, + )?; + let l1 = back::Level(1); + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;", + columns as u32, rows as u32 + )?; + let matrix_origin = "0"; + writeln!( + self.out, + "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);" + )?; + writeln!(self.out, "{l1}return m;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + + fn write_wrapped_cooperative_multiply_add( + &mut self, + module: &crate::Module, + func_ctx: &back::FunctionCtx, + space: crate::AddressSpace, + a: Handle, + b: Handle, + ) -> BackendResult { + let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + .. + } => (columns, rows, scalar), + _ => unreachable!(), + }; + let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) { + crate::TypeInner::CooperativeMatrix { columns, rows, .. 
} => (columns, rows), + _ => unreachable!(), + }; + let wrapped = WrappedFunction::CooperativeMultiplyAdd { + space, + columns: b_c, + rows: a_r, + intermediate: a_c, + scalar, + }; + if !self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + let space_name = space.to_msl_name().unwrap_or_default(); + let scalar_name = scalar.to_msl_name(); + writeln!( + self.out, + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", + b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, + )?; + let l1 = back::Level(1); + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", + b_c as u32, a_r as u32 + )?; + writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?; + writeln!(self.out, "{l1}return d;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, @@ -6329,6 +6501,24 @@ template crate::Expression::ImageQuery { image, query } => { self.write_wrapped_image_query(module, func_ctx, image, query)?; } + crate::Expression::CooperativeLoad { + columns, + rows, + role: _, + ref data, + } => { + self.write_wrapped_cooperative_load( + module, + func_ctx, + columns, + rows, + data.pointer, + )?; + } + crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { + let space = crate::AddressSpace::Private; + self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?; + } _ => {} } } @@ -6520,7 +6710,6 @@ template names: &self.names, handle, usage: fun_info[handle], - reference: true, }; let separator = diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d2b3ed70eda..6a5ce44289f 100644 --- a/naga/src/back/pipeline_constants.rs +++ 
b/naga/src/back/pipeline_constants.rs @@ -633,6 +633,19 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } + Expression::CooperativeLoad { ref mut data, .. } => { + adjust(&mut data.pointer); + adjust(&mut data.stride); + } + Expression::CooperativeMultiplyAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } @@ -835,6 +848,14 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } + Statement::CooperativeStore { + ref mut target, + ref mut data, + } => { + adjust(target); + adjust(&mut data.pointer); + adjust(&mut data.stride); + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 7758d86c414..14496957c12 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { crate::TypeInner::Scalar(_) => Dimension::Scalar, crate::TypeInner::Vector { .. } => Dimension::Vector, crate::TypeInner::Matrix { .. } => Dimension::Matrix, + crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix, _ => unreachable!(), } } @@ -766,6 +767,7 @@ impl BlockContext<'_> { rows, scalar, } => { + //TODO: why not just rely on `Fadd` for matrices? self.write_matrix_matrix_column_op( block, id, @@ -781,6 +783,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd, _ => unimplemented!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { @@ -809,6 +812,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. 
} => spirv::Op::FSub, _ => unimplemented!(), }, crate::BinaryOperator::Multiply => { @@ -842,10 +846,12 @@ impl BlockContext<'_> { (Dimension::Vector, Dimension::Matrix) => { spirv::Op::VectorTimesMatrix } - (Dimension::Matrix, Dimension::Scalar) => { + (Dimension::Matrix, Dimension::Scalar) + | (Dimension::CooperativeMatrix, Dimension::Scalar) => { spirv::Op::MatrixTimesScalar } - (Dimension::Scalar, Dimension::Matrix) => { + (Dimension::Scalar, Dimension::Matrix) + | (Dimension::Scalar, Dimension::CooperativeMatrix) => { reverse_operands = true; spirv::Op::MatrixTimesScalar } @@ -864,6 +870,12 @@ impl BlockContext<'_> { } (Dimension::Vector, Dimension::Vector) | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix) + //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication + | (Dimension::CooperativeMatrix, _) + | (_, Dimension::CooperativeMatrix) => { + unimplemented!() + } } } crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { @@ -1793,6 +1805,57 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } + crate::Expression::CooperativeLoad { ref data, .. } => { + self.writer.require_any( + "CooperativeMatrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + let pointer_id = match self.write_access_chain( + data.pointer, + block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. 
} => { + return Err(Error::FeatureNotImplemented( + "Copperative load/store out-of-bounds handling", + )); + } + }; + let layout = if data.row_major { + spirv::CooperativeMatrixLayout::RowMajorKHR + } else { + spirv::CooperativeMatrixLayout::ColumnMajorKHR + }; + let layout_id = self.get_index_constant(layout as u32); + let id = self.gen_id(); + block.body.push(Instruction::coop_load( + result_type_id, + id, + pointer_id, + layout_id, + self.cached[data.stride], + )); + id + } + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + self.writer.require_any( + "CooperativeMatrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + let a_id = self.cached[a]; + let b_id = self.cached[b]; + let c_id = self.cached[c]; + let id = self.gen_id(); + block.body.push(Instruction::coop_mul_add( + result_type_id, + id, + a_id, + b_id, + c_id, + )); + id + } }; self.cached[expr_handle] = id; @@ -3654,6 +3717,32 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + Statement::CooperativeStore { target, ref data } => { + let pointer_id = match self.write_access_chain( + data.pointer, + &mut block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. 
} => { + return Err(Error::FeatureNotImplemented( + "Copperative load/store out-of-bounds handling", + )); + } + }; + let layout = if data.row_major { + spirv::CooperativeMatrixLayout::RowMajorKHR + } else { + spirv::CooperativeMatrixLayout::ColumnMajorKHR + }; + let layout_id = self.get_index_constant(layout as u32); + block.body.push(Instruction::coop_store( + self.cached[target], + pointer_id, + layout_id, + self.cached[data.stride], + )); + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 788c3bc119a..22eaa99340f 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -281,6 +281,24 @@ impl super::Instruction { instruction } + pub(super) fn type_coop_matrix( + id: Word, + scalar_type_id: Word, + scope_id: Word, + row_count_id: Word, + column_count_id: Word, + matrix_use_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); + instruction.set_result(id); + instruction.add_operand(scalar_type_id); + instruction.add_operand(scope_id); + instruction.add_operand(row_count_id); + instruction.add_operand(column_count_id); + instruction.add_operand(matrix_use_id); + instruction + } + #[allow(clippy::too_many_arguments)] pub(super) fn type_image( id: Word, @@ -1227,6 +1245,41 @@ impl super::Instruction { instruction } + + // Cooperative operations + pub(super) fn coop_load( + result_type_id: Word, + id: Word, + pointer_id: Word, + layout_id: Word, + stride_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + instruction.add_operand(layout_id); + instruction.add_operand(stride_id); + instruction + } + pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); + instruction.add_operand(pointer_id); + 
instruction.add_operand(id); + instruction.add_operand(layout_id); + instruction.add_operand(stride_id); + instruction + } + pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(a); + instruction.add_operand(b); + instruction.add_operand(c); + + instruction + } } impl From for spirv::ImageFormat { @@ -1289,3 +1342,13 @@ impl From for spirv::Dim { } } } + +impl From for spirv::CooperativeMatrixUse { + fn from(role: crate::CooperativeRole) -> Self { + match role { + crate::CooperativeRole::A => Self::MatrixAKHR, + crate::CooperativeRole::B => Self::MatrixBKHR, + crate::CooperativeRole::C => Self::MatrixAccumulatorKHR, + } + } +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 371b3f7dbec..90f90f30eea 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -340,6 +340,36 @@ impl NumericType { } } +/// A cooperative type, for use in [`LocalType`]. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum CooperativeType { + Matrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + scalar: crate::Scalar, + role: crate::CooperativeRole, + }, +} + +impl CooperativeType { + const fn from_inner(inner: &crate::TypeInner) -> Option { + match *inner { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + } => Some(Self::Matrix { + columns, + rows, + scalar, + role, + }), + _ => None, + } + } +} + /// A SPIR-V type constructed during code generation. /// /// This is the variant of [`LookupType`] used to represent types that might not @@ -389,6 +419,7 @@ impl NumericType { enum LocalType { /// A numeric type. 
Numeric(NumericType), + Cooperative(CooperativeType), Pointer { base: Word, class: spirv::StorageClass, @@ -451,6 +482,7 @@ enum Dimension { Scalar, Vector, Matrix, + CooperativeMatrix, } /// Key used to look up an operation which we have wrapped in a helper diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 636766d1e5f..88c3a1629ce 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -6,10 +6,11 @@ use spirv::Word; use super::{ block::DebugInfoInner, helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, - Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error, - Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType, - LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, - PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, + Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo, + EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, + LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, + NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, + BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, @@ -436,6 +437,9 @@ impl Writer { // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } + crate::TypeInner::CooperativeMatrix { .. } => { + LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap()) + } crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { @@ -967,14 +971,13 @@ impl Writer { } } - // Handle globals are pre-emitted and should be loaded automatically. - // - // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. 
match ir_module.types[var.ty].inner { + // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. crate::TypeInner::BindingArray { .. } => { gv.access_id = gv.var_id; } _ => { + // Handle globals are pre-emitted and should be loaded automatically. if var.space == crate::AddressSpace::Handle { let var_type_id = self.get_handle_type_id(var.ty); let id = self.id_gen.next(); @@ -1060,6 +1063,7 @@ impl Writer { } }), ); + context .function .variables @@ -1352,6 +1356,16 @@ impl Writer { self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?; self.use_extension("SPV_KHR_16bit_storage"); } + // Cooperative types and ops + crate::TypeInner::CooperativeMatrix { .. } => { + self.require_any( + "cooperative matrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?; + self.use_extension("SPV_KHR_cooperative_matrix"); + self.use_extension("SPV_KHR_vulkan_memory_model"); + } _ => {} } Ok(()) @@ -1378,12 +1392,38 @@ impl Writer { instruction.to_words(&mut self.logical_layout.declarations); } + fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) { + let instruction = match coop { + CooperativeType::Matrix { + columns, + rows, + scalar, + role, + } => { + let scalar_id = + self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar))); + let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let columns_id = self.get_index_constant(columns as u32); + let rows_id = self.get_index_constant(rows as u32); + let role_id = + self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32); + Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id) + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { let instruction = match local_ty { 
LocalType::Numeric(numeric) => { self.write_numeric_type_declaration_local(id, numeric); return; } + LocalType::Cooperative(coop) => { + self.write_cooperative_type_declaration_local(id, coop); + return; + } LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base), LocalType::Image(image) => { let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type)); @@ -1500,6 +1540,7 @@ impl Writer { | crate::TypeInner::Atomic(_) | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. } @@ -2630,7 +2671,14 @@ impl Writer { } let addressing_model = spirv::AddressingModel::Logical; - let memory_model = spirv::MemoryModel::GLSL450; + let memory_model = if self + .capabilities_used + .contains(&spirv::Capability::VulkanMemoryModel) + { + spirv::MemoryModel::Vulkan + } else { + spirv::MemoryModel::GLSL450 + }; //self.check(addressing_model.required_capabilities())?; //self.check(memory_model.required_capabilities())?; diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 225a63343bf..4e81a2f7cd8 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -984,6 +984,16 @@ impl Writer { } writeln!(self.out, ");")?; } + Statement::CooperativeStore { target, ref data } => { + let suffix = if data.row_major { "T" } else { "" }; + write!(self.out, "{level}coopStore{suffix}(")?; + self.write_expr(module, target, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, data.pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, data.stride, func_ctx)?; + writeln!(self.out, ");")? + } } Ok(()) @@ -1101,6 +1111,13 @@ impl Writer { // If the plain form of the expression is not what we need, emit the // operator necessary to correct that. 
let plain = self.plain_form_indirection(expr, module, func_ctx); + log::trace!( + "expression {:?}={:?} is {:?}, expected {:?}", + expr, + func_ctx.expressions[expr], + plain, + requested, + ); match (requested, plain) { (Indirection::Ordinary, Indirection::Reference) => { write!(self.out, "(&")?; @@ -1695,6 +1712,43 @@ impl Writer { | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} + Expression::CooperativeLoad { + columns, + rows, + role, + ref data, + } => { + let suffix = if data.row_major { "T" } else { "" }; + let scalar = func_ctx.info[data.pointer] + .ty + .inner_with(&module.types) + .pointer_base_type() + .unwrap() + .inner_with(&module.types) + .scalar() + .unwrap(); + write!( + self.out, + "coopLoad{suffix}>(", + columns as u32, + rows as u32, + scalar.try_to_wgsl().unwrap(), + role, + )?; + self.write_expr(module, data.pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, data.stride, func_ctx)?; + write!(self.out, ")")?; + } + Expression::CooperativeMultiplyAdd { a, b, c } => { + write!(self.out, "coopMultiplyAdd(")?; + self.write_expr(module, a, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, b, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, c, func_ctx)?; + write!(self.out, ")")?; + } } Ok(()) diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 72be441288f..1cdf3eb5cff 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -299,15 +299,23 @@ impl TryToWgsl for crate::Scalar { } } -impl ToWgsl for crate::ImageDimension { +impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { - use crate::ImageDimension as IDim; + match self { + Self::A => "A", + Self::B => "B", + Self::C => "C", + } + } +} +impl ToWgsl for crate::ImageDimension { + fn to_wgsl(self) -> &'static str { match self { - IDim::D1 => "1d", - IDim::D2 => "2d", - IDim::D3 
=> "3d", - IDim::Cube => "cube", + Self::D1 => "1d", + Self::D2 => "2d", + Self::D3 => "3d", + Self::Cube => "cube", } } } diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 82b8eeaa67a..a678a617f76 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -317,6 +317,21 @@ where ctx.write_scalar(scalar, out)?; out.write_str(">")?; } + TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + } => { + write!( + out, + "coop_mat{}x{}<{},{}>", + columns as u32, + rows as u32, + scalar.try_to_wgsl().unwrap_or_default(), + role.to_wgsl(), + )?; + } TypeInner::Pointer { base, space } => { let (address, maybe_access) = address_space_str(space); // Everything but `AddressSpace::Handle` gives us a `address` name, but diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index f36d747a935..021401d00ef 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,6 +253,15 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } + Ex::CooperativeLoad { ref data, .. } => { + self.expressions_used.insert(data.pointer); + self.expressions_used.insert(data.stride); + } + Ex::CooperativeMultiplyAdd { a, b, c } => { + self.expressions_used.insert(a); + self.expressions_used.insert(b); + self.expressions_used.insert(c); + } } } } @@ -419,6 +428,19 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), + Ex::CooperativeLoad { ref mut data, .. 
} => { + adjust(&mut data.pointer); + adjust(&mut data.stride); + } + Ex::CooperativeMultiplyAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 39d6065f5f0..af72cb872ae 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -152,6 +152,11 @@ impl FunctionTracer<'_> { self.expressions_used.insert(argument); self.expressions_used.insert(result); } + St::CooperativeStore { target, ref data } => { + self.expressions_used.insert(target); + self.expressions_used.insert(data.pointer); + self.expressions_used.insert(data.stride); + } // Trivial statements. St::Break @@ -371,6 +376,14 @@ impl FunctionMap { adjust(argument); adjust(result); } + St::CooperativeStore { + ref mut target, + ref mut data, + } => { + adjust(target); + adjust(&mut data.pointer); + adjust(&mut data.stride); + } // Trivial statements. St::Break diff --git a/naga/src/compact/types.rs b/naga/src/compact/types.rs index 0a1db16f9f6..d06558b182c 100644 --- a/naga/src/compact/types.rs +++ b/naga/src/compact/types.rs @@ -16,6 +16,7 @@ impl TypeTracer<'_> { Ti::Scalar { .. } | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic { .. } | Ti::ValuePointer { .. } | Ti::Image { .. } @@ -66,6 +67,7 @@ impl ModuleMap { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::ValuePointer { .. } | Ti::Image { .. } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 803e52553dc..aa29bbffae6 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4654,6 +4654,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::CooperativeStore { .. 
} => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0ea..b4c5c9b99c1 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -412,6 +412,9 @@ pub(crate) enum Error<'a> { TypeTooLarge { span: Span, }, + UnderspecifiedCooperativeMatrix, + InvalidCooperativeLoadType(Span), + UnsupportedCooperativeScalar(Span), } impl From for Error<'_> { @@ -1386,6 +1389,21 @@ impl<'a> Error<'a> { crate::valid::MAX_TYPE_SIZE )], }, + Error::UnderspecifiedCooperativeMatrix => ParseError { + message: "cooperative matrix constructor is underspecified".into(), + labels: vec![], + notes: vec![format!("must be F32")], + }, + Error::InvalidCooperativeLoadType(span) => ParseError { + message: "cooperative load should have a generic type for coop_mat".into(), + labels: vec![(span, "type needs the coop_mat<...>".into())], + notes: vec![format!("must be a valid cooperative type")], + }, + Error::UnsupportedCooperativeScalar(span) => ParseError { + message: "cooperative scalar type is not supported".into(), + labels: vec![(span, "type needs the scalar type specified".into())], + notes: vec![format!("must be F32")], + }, } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 997d5a31238..2159ef01ad0 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -638,6 +638,29 @@ impl<'source> Lowerer<'source, '_> { }; Constructor::Type(ty) } + ast::ConstructorType::PartialCooperativeMatrix { .. 
} => { + return Err(Box::new(Error::UnderspecifiedCooperativeMatrix)); + } + ast::ConstructorType::CooperativeMatrix { + rows, + columns, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; + let scalar = match ctx.module.types[ty].inner { + crate::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), + }; + let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + }); + Constructor::Type(ty) + } ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { let base = self.resolve_ast_type(base, &mut ctx.as_const())?; diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs index b22692a3cd9..9e03ed5c9e2 100644 --- a/naga/src/front/wgsl/lower/conversion.rs +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -350,6 +350,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), Ti::Atomic(_) | Ti::Pointer { .. } @@ -375,6 +376,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Atomic(_) => None, Ti::Pointer { base, .. } | Ti::Array { base, .. 
} => { types[base].inner.automatically_convertible_scalar(types) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e90d7eab0a8..326455c8ffe 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -524,6 +524,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result<'source, Handle> { let mut eval = self.as_const_evaluator(); + log::debug!("appending {expr:?}"); eval.try_eval_and_append(expr, span) .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span))) } @@ -846,6 +847,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle { self.as_global().ensure_type_exists(None, inner) } + + fn _get_runtime_expression(&self, expr: Handle) -> &ir::Expression { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr], + ExpressionContextType::Constant(_) | ExpressionContextType::Override => { + unreachable!() + } + } + } } struct ArgumentContext<'ctx, 'source> { @@ -955,6 +965,13 @@ impl Typed { Self::Plain(expr) => Typed::Plain(f(expr)?), }) } + + fn ref_or(self, error: E) -> core::result::Result { + match self { + Self::Reference(v) => Ok(v), + Self::Plain(_) => Err(error), + } + } } /// A single vector component or swizzle. 
@@ -1888,6 +1905,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { stmt.span, function, arguments, + None, &mut ctx.as_expression(block, &mut emitter), true, )?; @@ -1973,12 +1991,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value_span = ctx.ast_expressions.get_span(value); let target = self .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; - let target_handle = match target { - Typed::Reference(handle) => handle, - Typed::Plain(_) => { - return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))) - } - }; + let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?; let mut ectx = ctx.as_expression(block, &mut emitter); let scalar = match *resolve_inner!(ectx, target_handle) { @@ -2134,7 +2147,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr = match *global { LoweredGlobalDecl::Var(handle) => { let expr = ir::Expression::GlobalVariable(handle); - match ctx.module.global_variables[handle].space { + let v = &ctx.module.global_variables[handle]; + match v.space { ir::AddressSpace::Handle => Typed::Plain(expr), _ => Typed::Reference(expr), } @@ -2214,9 +2228,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Expression::Call { ref function, ref arguments, + result_ty, } => { let handle = self - .call(span, function, arguments, ctx, false)? + .call(span, function, arguments, result_ty, ctx, false)? 
.ok_or(Error::FunctionReturnsVoid(function.span))?; return Ok(Typed::Plain(handle)); } @@ -2411,6 +2426,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span: Span, function: &ast::Ident<'source>, arguments: &[Handle>], + result_ty: Option<(Handle>, Span)>, ctx: &mut ExpressionContext<'source, '_, '_>, is_statement: bool, ) -> Result<'source, Option>> { @@ -3082,7 +3098,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3106,7 +3121,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3130,6 +3144,101 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } + "coopLoad" | "coopLoadT" => { + let row_major = function.name.ends_with("T"); + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.expression(args.next()?, ctx)?; + let (matrix_ty, matrix_span) = result_ty.expect("generic argument"); + let (columns, rows, role) = match ctx.types[matrix_ty] { + ast::Type::CooperativeMatrix { + columns, + rows, + role, + .. + } => (columns, rows, role), + _ => { + return Err(Box::new(Error::InvalidCooperativeLoadType( + matrix_span, + ))) + } + }; + let stride = if args.total_args > 1 { + self.expression(args.next()?, ctx)? + } else { + // Infer the stride from the matrix type + let stride = if row_major { + columns as u32 + } else { + rows as u32 + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? 
+ }; + args.finish()?; + + crate::Expression::CooperativeLoad { + columns, + rows, + role, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + } + } + "coopStore" | "coopStoreT" => { + let row_major = function.name.ends_with("T"); + + let mut args = ctx.prepare_args(arguments, 2, span); + let target = self.expression(args.next()?, ctx)?; + let pointer = self.expression(args.next()?, ctx)?; + let stride = if args.total_args > 2 { + self.expression(args.next()?, ctx)? + } else { + // Infer the stride from the matrix type + let stride = match *resolve_inner!(ctx, target) { + ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? + }; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::CooperativeStore { + target, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + }, + span, + ); + return Ok(None); + } + "coopMultiplyAdd" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let a = self.expression(args.next()?, ctx)?; + let b = self.expression(args.next()?, ctx)?; + let c = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::CooperativeMultiplyAdd { a, b, c } + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } @@ -3955,6 +4064,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))), } } + ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, ctx)?; + let scalar = match ctx.module.types[ty].inner { + ir::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), + }; + ir::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + } + } 
ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), ast::Type::Pointer { base, space } => { let base = self.resolve_ast_type(base, ctx)?; diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c486..da093a2f068 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -235,6 +235,13 @@ pub enum Type<'a> { ty: Handle>, ty_span: Span, }, + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, Atomic(Scalar), Pointer { base: Handle>, @@ -385,6 +392,21 @@ pub enum ConstructorType<'a> { ty_span: Span, }, + /// A cooperative matrix construction base `coop_mat8x8(...)`. + PartialCooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + }, + + /// A full cooperative matrix construction `coop_mat8x8(...)`. + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, + /// An array whose component type and size are inferred from the arguments: /// `array(3,4,5)`. 
PartialArray, @@ -465,6 +487,7 @@ pub enum Expression<'a> { Call { function: Ident<'a>, arguments: Vec>>, + result_ty: Option<(Handle>, Span)>, }, Index { base: Handle>, diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index d0a8033987b..ed87e371008 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -584,6 +584,18 @@ impl<'a> Lexer<'a> { }) } + pub(in crate::front::wgsl) fn next_cooperative_role( + &mut self, + ) -> Result<'a, crate::CooperativeRole> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "A" => Ok(crate::CooperativeRole::A), + "B" => Ok(crate::CooperativeRole::B), + "C" => Ok(crate::CooperativeRole::C), + _ => Err(Box::new(Error::UnknownAccess(span))), + } + } + pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<'a, ()> { self.expect(Token::Paren('(')) } diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30f..50275c50785 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -658,6 +658,10 @@ impl Parser { ty_span: Span::UNDEFINED, })) } + "coop_mat8x8" => ast::ConstructorType::PartialCooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Eight, + }, "array" => ast::ConstructorType::PartialArray, "atomic" | "binding_array" @@ -701,6 +705,19 @@ impl Parser { ty_span, })) } + ( + Token::Paren('<'), + ast::ConstructorType::PartialCooperativeMatrix { columns, rows }, + ) => { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(Some(ast::ConstructorType::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + })) + } (Token::Paren('<'), ast::ConstructorType::PartialArray) => { lexer.expect_generic_paren('<')?; let base = self.type_decl(lexer, ctx)?; @@ -783,6 +800,11 @@ impl Parser { } // everything else must be handled later, since they can be hidden by user-defined functions. 
_ => { + let result_ty = if lexer.peek().0 == Token::Paren('<') { + Some(self.singular_generic(lexer, ctx)?) + } else { + None + }; let arguments = self.arguments(lexer, ctx)?; ctx.unresolved.insert(ast::Dependency { ident: name, @@ -794,6 +816,7 @@ impl Parser { span: name_span, }, arguments, + result_ty, } } }; @@ -942,7 +965,7 @@ impl Parser { } else if let Token::Paren('(') = lexer.peek().0 { self.pop_rule_span(lexer); return self.function_call(lexer, word, span, ctx); - } else if word == "bitcast" { + } else if ["bitcast", "coopLoad"].contains(&word) { self.pop_rule_span(lexer); return self.function_call(lexer, word, span, ctx); } else { @@ -1437,6 +1460,22 @@ impl Parser { Ok((ty, span)) } + /// Parses ``, returning (T, span of T, R, span of R) + fn cooperative_scalar_and_role<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<'a, (Handle>, Span, crate::CooperativeRole)> { + lexer.expect_generic_paren('<')?; + let start = lexer.start_byte_offset(); + let ty = self.type_decl(lexer, ctx)?; + let ty_span = lexer.span_from(start); + lexer.expect(Token::Separator(','))?; + let role = lexer.next_cooperative_role()?; + lexer.expect_generic_paren('>')?; + Ok((ty, ty_span, role)) + } + fn matrix_with_type<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1453,6 +1492,23 @@ impl Parser { }) } + fn cooperative_matrix_with_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ) -> Result<'a, ast::Type<'a>> { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + }) + } + fn type_decl_impl<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1684,6 +1740,12 @@ impl Parser { ty: ctx.new_scalar(Scalar::F16), ty_span: Span::UNDEFINED, }, + "coop_mat8x8" => self.cooperative_matrix_with_type( + lexer, + ctx, + crate::CooperativeSize::Eight, 
+ crate::CooperativeSize::Eight, + )?, "atomic" => { let scalar = lexer.next_scalar_generic()?; ast::Type::Atomic(scalar) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 257445952b8..e1a90619b72 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -437,6 +437,16 @@ impl From for u32 { } } +/// Number of components in a cooperative vector. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeSize { + Eight = 8, +} + /// Primitive type for a scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -464,6 +474,18 @@ pub enum ScalarKind { AbstractFloat, } +/// Role of a cooperative variable in the equation "A * B + C" +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeRole { + A, + B, + C, +} + /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -712,6 +734,14 @@ pub enum TypeInner { rows: VectorSize, scalar: Scalar, }, + /// Matrix that is cooperatively processed by all the threads + /// in an opaque mapping. + CooperativeMatrix { + columns: CooperativeSize, + rows: CooperativeSize, + scalar: Scalar, + role: CooperativeRole, + }, /// Atomic scalar. Atomic(Scalar), /// Pointer to another type. @@ -1391,6 +1421,16 @@ bitflags::bitflags! 
{ } } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct CooperativeData { + pub pointer: Handle, + pub stride: Handle, + pub row_major: bool, +} + /// An expression that can be evaluated to obtain a value. /// /// This is a Single Static Assignment (SSA) scheme similar to SPIR-V. @@ -1733,6 +1773,20 @@ pub enum Expression { /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, + + /// Load a cooperative primitive from memory. + CooperativeLoad { + columns: CooperativeSize, + rows: CooperativeSize, + role: CooperativeRole, + data: CooperativeData, + }, + /// Compute `a * b + c` + CooperativeMultiplyAdd { + a: Handle, + b: Handle, + c: Handle, + }, } /// The value of the switch case. @@ -2174,6 +2228,11 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, + /// Store a cooperative primitive into memory. + CooperativeStore { + target: Handle, + data: CooperativeData, + }, } /// A function argument. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index f5a5d25ca87..c80789b2df7 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -584,6 +584,8 @@ pub enum ConstantEvaluatorError { "Expected reject and accept args. to be scalars of vectors of the same type, got something else", )] SelectAcceptRejectTypeMismatch, + #[error("Cooperative operations can't be constant")] + CooperativeOperation, } impl<'a> ConstantEvaluator<'a> { @@ -971,6 +973,9 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } + Expression::CooperativeLoad { .. 
} | Expression::CooperativeMultiplyAdd { .. } => { + Err(ConstantEvaluatorError::CooperativeOperation) + } } } diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 204a523c91b..5165ac7a013 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -86,6 +86,12 @@ impl From for Alignment { } } +impl From for Alignment { + fn from(size: crate::CooperativeSize) -> Self { + Self(unsafe { NonZeroU32::new_unchecked(size as u32) }) + } +} + /// Size and alignment information for a type. #[derive(Clone, Copy, Debug, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -212,6 +218,19 @@ impl Layouter { alignment: Alignment::from(rows) * alignment, } } + Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + role: _, + } => { + let alignment = Alignment::new(scalar.width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { size, alignment: Alignment::ONE, diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index b29ccb054a3..6ffd8159303 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -43,7 +43,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } | S::ControlBarrier(_) - | S::MemoryBarrier(_)), + | S::MemoryBarrier(_) + | S::CooperativeStore { .. }), ) | None => block.push(S::Return { value: None }, Default::default()), } diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c59d524f13e..fe3eb4b6268 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -115,6 +115,7 @@ impl crate::TypeInner { match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), Ti::Matrix { scalar, .. } => Some(scalar), + Ti::CooperativeMatrix { scalar, .. 
} => Some(scalar), _ => None, } } @@ -182,8 +183,8 @@ impl crate::TypeInner { pub fn is_atomic_pointer(&self, types: &crate::UniqueArena) -> bool { match *self { - crate::TypeInner::Pointer { base, .. } => match types[base].inner { - crate::TypeInner::Atomic { .. } => true, + Self::Pointer { base, .. } => match types[base].inner { + Self::Atomic { .. } => true, _ => false, }, _ => false, @@ -202,6 +203,12 @@ impl crate::TypeInner { rows, scalar, } => Some(super::Alignment::from(rows) * scalar.width as u32 * columns as u32), + Self::CooperativeMatrix { + columns, + rows, + scalar, + role: _, + } => Some(columns as u32 * rows as u32 * scalar.width as u32), Self::Pointer { .. } | Self::ValuePointer { .. } => Some(POINTER_SPAN), Self::Array { base: _, @@ -361,6 +368,7 @@ impl crate::TypeInner { crate::TypeInner::Scalar(scalar) => Some((None, scalar)), crate::TypeInner::Vector { size, scalar } => Some((Some(size), scalar)), crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Atomic(_) | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } @@ -385,7 +393,8 @@ impl crate::TypeInner { | crate::TypeInner::Matrix { scalar, .. } | crate::TypeInner::Atomic(scalar) => scalar.is_abstract(), crate::TypeInner::Array { base, .. } => types[base].inner.is_abstract(types), - crate::TypeInner::ValuePointer { .. } + crate::TypeInner::CooperativeMatrix { .. } + | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::Struct { .. } | crate::TypeInner::Image { .. 
} diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 79b4f95e106..1d469807744 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -143,6 +143,17 @@ impl Clone for TypeResolution { columns, scalar, }, + Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + } => Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }, Ti::Pointer { base, space } => Ti::Pointer { base, space }, Ti::ValuePointer { size, @@ -476,7 +487,7 @@ impl<'a> ResolveContext<'a> { None => Ti::Scalar(scalar), }), ref other => { - log::error!("Pointer type {other:?}"); + log::error!("Pointer {pointer:?} type {other:?}"); return Err(ResolveError::InvalidPointer(pointer)); } }, @@ -587,6 +598,20 @@ impl<'a> ResolveContext<'a> { (&Ti::Scalar { .. }, _) => res_right.clone(), (_, &Ti::Scalar { .. }) => res_left.clone(), (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(), + ( + &Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + role, + }, + &Ti::CooperativeMatrix { columns, .. }, + ) => TypeResolution::Value(Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }), (tl, tr) => { return Err(ResolveError::IncompatibleOperands(format!( "{tl:?} * {tr:?}" @@ -776,6 +801,25 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), + crate::Expression::CooperativeLoad { + columns, + rows, + role, + ref data, + } => { + let scalar = past(data.pointer)? + .inner_with(types) + .pointer_base_type() + .and_then(|tr| tr.inner_with(types).scalar()) + .ok_or(ResolveError::InvalidPointer(data.pointer))?; + TypeResolution::Value(Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }) + } + crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 95ae40dcdb4..47a43731b81 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -29,6 +29,7 @@ bitflags::bitflags! 
{ const WORK_GROUP_BARRIER = 0x1; const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; + const COOP_OPS = 0x8; } } @@ -822,6 +823,14 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::CooperativeLoad { ref data, .. } => Uniformity { + non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + E::CooperativeMultiplyAdd { a, b, c } => Uniformity { + non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), + requirements: UniformityRequirements::COOP_OPS, + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -1151,6 +1160,16 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::CooperativeStore { target, ref data } => FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .add_ref(target) + .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE)) + .or(self.add_ref(data.stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + }, }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 68023b5bf01..8f39689642c 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -141,6 +141,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), + #[error("Invalid operand for cooperative op")] + InvalidCooperativeOperand(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -788,7 +790,9 @@ impl super::Validator { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => 
false, }, - Ti::Matrix { .. } => left_inner == right_inner, + Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => { + left_inner == right_inner + } _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { @@ -818,7 +822,7 @@ impl super::Validator { scalar: scalar2, .. }, ) => scalar1 == scalar2, - // Scalar/matrix. + // Scalar * matrix. ( &Ti::Scalar(Sc { kind: Sk::Float, .. @@ -831,7 +835,7 @@ impl super::Validator { kind: Sk::Float, .. }), ) => true, - // Vector/vector. + // Vector * vector. ( &Ti::Vector { size: size1, @@ -864,9 +868,30 @@ impl super::Validator { }, &Ti::Matrix { rows, .. }, ) => size == rows, + // Matrix * matrix. (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { columns == rows } + // Coop matrix * coop matrix. + ( + &Ti::CooperativeMatrix { + columns, + scalar: scalar1, + role: role1, + .. + }, + &Ti::CooperativeMatrix { + rows, + scalar: scalar2, + role: role2, + .. + }, + ) => columns == rows && scalar1 == scalar2 && role1 == role2, + // Scalar * coop matrix. + (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. }) + | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => { + s1 == s2 + } _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); @@ -1230,6 +1255,33 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, + E::CooperativeLoad { ref data, .. } => { + if resolver[data.pointer] + .pointer_base_type() + .and_then(|tr| tr.inner_with(&module.types).scalar()) + .is_none() + { + return Err(ExpressionError::InvalidPointerType(data.pointer)); + } + ShaderStages::COMPUTE + } + E::CooperativeMultiplyAdd { a, b, c } => { + let roles = [ + crate::CooperativeRole::A, + crate::CooperativeRole::B, + crate::CooperativeRole::C, + ]; + for (operand, expected_role) in [a, b, c].into_iter().zip(roles) { + match resolver[operand] { + Ti::CooperativeMatrix { role, .. 
} if role == expected_role => {} + ref other => { + log::error!("{expected_role:?} operand type: {other:?}"); + return Err(ExpressionError::InvalidCooperativeOperand(operand)); + } + } + } + ShaderStages::COMPUTE + } }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dc19e191764..5c45aa44079 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1,6 +1,5 @@ use alloc::{format, string::String}; -use super::validate_atomic_compare_exchange_struct; use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, @@ -213,6 +212,10 @@ pub enum FunctionError { WorkgroupUniformLoadInvalidPointer(Handle), #[error("Subgroup operation is invalid")] InvalidSubgroup(#[from] SubgroupError), + #[error("Invalid target type for a cooperative store")] + InvalidCooperativeStoreTarget(Handle), + #[error("Cooperative load/store data pointer has invalid type")] + InvalidCooperativeDataPointer(Handle), #[error("Emit statement should not cover \"result\" expressions like {0:?}")] EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] @@ -576,7 +579,7 @@ impl super::Validator { .with_span_handle(result, context.expressions) .into_other()); }; - if !validate_atomic_compare_exchange_struct( + if !super::validate_atomic_compare_exchange_struct( context.types, members, |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar), @@ -796,7 +799,9 @@ impl super::Validator { | Ex::As { .. } | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } - | Ex::RayQueryVertexPositions { .. } => { + | Ex::RayQueryVertexPositions { .. } + | Ex::CooperativeLoad { .. } + | Ex::CooperativeMultiplyAdd { .. } => { self.emit_expression(handle, context)? 
} Ex::CallResult(_) @@ -1618,6 +1623,37 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } + S::CooperativeStore { target, ref data } => { + stages &= super::ShaderStages::COMPUTE; + + let target_scalar = + match *context.resolve_type_inner(target, &self.valid_expression_set)? { + Ti::CooperativeMatrix { scalar, .. } => scalar, + ref other => { + log::error!("Target operand type: {other:?}"); + return Err(FunctionError::InvalidCooperativeStoreTarget(target) + .with_span_handle(target, context.expressions)); + } + }; + + let ptr_ty = context.resolve_pointer_type(data.pointer); + let ptr_scalar = ptr_ty + .pointer_base_type() + .and_then(|tr| tr.inner_with(context.types).scalar()); + if ptr_scalar != Some(target_scalar) { + return Err(FunctionError::InvalidCooperativeDataPointer(data.pointer) + .with_span_handle(data.pointer, context.expressions)); + } + + let ptr_space = ptr_ty.pointer_space().unwrap_or(AddressSpace::Handle); + if !ptr_space.access().contains(crate::StorageAccess::STORE) { + return Err(FunctionError::InvalidStorePointer(data.pointer) + .with_span_static( + context.expressions.get_span(data.pointer), + "writing to this location is not permitted", + )); + } + } } } Ok(BlockInfo { stages }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e8a69013434..4df550dbc60 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -379,6 +379,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. } @@ -647,6 +648,12 @@ impl super::Validator { } => { handle.check_dep(query)?; } + crate::Expression::CooperativeLoad { ref data, .. 
} => { + handle.check_dep(data.pointer)?.check_dep(data.stride)?; + } + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; + } } Ok(()) } @@ -835,6 +842,12 @@ impl super::Validator { validate_expr(result)?; Ok(()) } + crate::Statement::CooperativeStore { target, ref data } => { + validate_expr(target)?; + validate_expr(data.pointer)?; + validate_expr(data.stride)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 426b3d637d7..fa0fdb0d393 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -186,6 +186,8 @@ bitflags::bitflags! { /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store /// `f16`-precision values in `f32`s. const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28; + /// Support for cooperative matrix types and operations + const COOPERATIVE_MATRIX = 1 << 29; } } @@ -451,6 +453,7 @@ impl crate::TypeInner { Self::Scalar { .. } | Self::Vector { .. } | Self::Matrix { .. } + | Self::CooperativeMatrix { .. } | Self::Array { size: crate::ArraySize::Constant(_), .. 
diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e8b83ff08f3..93cdae34e16 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -415,6 +415,27 @@ impl super::Validator { type_info.push_constant_compatibility = push_constant_compatibility; type_info } + Ti::CooperativeMatrix { + columns: _, + rows: _, + scalar, + role: _, + } => { + self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; + if scalar.kind != crate::ScalarKind::Float || scalar.width != 4 { + return Err(TypeError::MatrixElementNotFloat); + } + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED, + Alignment::from_width(scalar.width), + ) + } Ti::Atomic(scalar) => { match scalar { crate::Scalar { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml new file mode 100644 index 00000000000..7d20269efc2 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -0,0 +1,9 @@ +targets = "IR | SPIRV | METAL | WGSL" +god_mode = true + +[spv] +debug = true +version = [1, 4] + +[msl] +lang_version = [2, 3] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl new file mode 100644 index 00000000000..06afb5cdf4e --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -0,0 +1,12 @@ +var a: coop_mat8x8; +var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c = coopLoad>(&ext[4]); + var d = coopMultiplyAdd(a, b, c); + coopStore(d, &ext[0]); + c = d; +} diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron new file mode 100644 index 00000000000..d31e45cd6f9 --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -0,0 +1,227 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: 
Float, + width: 4, + )), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: A, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, + ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: None, + ), + ( + name: Some("d"), + ty: 4, + init: None, + ), + ], + expressions: [ + GlobalVariable(2), + AccessIndex( + base: 0, + index: 4, + ), + Literal(U32(8)), + CooperativeLoad( + columns: Eight, + rows: Eight, + role: C, + data: ( + pointer: 1, + stride: 2, + row_major: false, + ), + ), + LocalVariable(0), + GlobalVariable(0), + Load( + pointer: 5, + ), + GlobalVariable(1), + Load( + pointer: 7, + ), + Load( + pointer: 4, + ), + CooperativeMultiplyAdd( + a: 6, + b: 8, + c: 9, + ), + LocalVariable(1), + Load( + pointer: 11, + ), + 
GlobalVariable(2), + AccessIndex( + base: 13, + index: 0, + ), + Literal(U32(8)), + Load( + pointer: 11, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + end: 4, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 6, + end: 7, + )), + Emit(( + start: 8, + end: 11, + )), + Store( + pointer: 11, + value: 10, + ), + Emit(( + start: 12, + end: 13, + )), + Emit(( + start: 14, + end: 15, + )), + CooperativeStore( + target: 12, + data: ( + pointer: 14, + stride: 15, + row_major: false, + ), + ), + Emit(( + start: 16, + end: 17, + )), + Store( + pointer: 4, + value: 16, + ), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron new file mode 100644 index 00000000000..d31e45cd6f9 --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -0,0 +1,227 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: A, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, 
+ ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: None, + ), + ( + name: Some("d"), + ty: 4, + init: None, + ), + ], + expressions: [ + GlobalVariable(2), + AccessIndex( + base: 0, + index: 4, + ), + Literal(U32(8)), + CooperativeLoad( + columns: Eight, + rows: Eight, + role: C, + data: ( + pointer: 1, + stride: 2, + row_major: false, + ), + ), + LocalVariable(0), + GlobalVariable(0), + Load( + pointer: 5, + ), + GlobalVariable(1), + Load( + pointer: 7, + ), + Load( + pointer: 4, + ), + CooperativeMultiplyAdd( + a: 6, + b: 8, + c: 9, + ), + LocalVariable(1), + Load( + pointer: 11, + ), + GlobalVariable(2), + AccessIndex( + base: 13, + index: 0, + ), + Literal(U32(8)), + Load( + pointer: 11, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + end: 4, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 6, + end: 7, + )), + Emit(( + start: 8, + end: 11, + )), + Store( + pointer: 11, + value: 10, + ), + Emit(( + start: 12, + end: 13, + )), + Emit(( + start: 14, + end: 15, + )), + CooperativeStore( + target: 12, + data: ( + pointer: 14, + stride: 15, + row_major: false, + ), + ), + Emit(( + start: 16, + end: 17, + )), + Store( + pointer: 4, + value: 16, + ), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git 
a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl new file mode 100644 index 00000000000..604ec4a169a --- /dev/null +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -0,0 +1,43 @@ +// language: metal2.3 +#include +#include + +using metal::uint; + +struct _mslBufferSizes { + uint size2; +}; + +typedef float type_3[1]; +metal::simdgroup_float8x8 NagaCooperativeLoad(const device float* ptr, int stride, bool is_row_major) { + metal::simdgroup_float8x8 m; + simdgroup_load(m, ptr, stride, 0, is_row_major); + return m; +} + +metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const thread metal::simdgroup_float8x8& a, const thread metal::simdgroup_float8x8& b, const thread metal::simdgroup_float8x8& c) { + metal::simdgroup_float8x8 d; + simdgroup_multiply_accumulate(d,a,b,c); + return d; +} + + +kernel void main_( + device type_3& ext [[user(fake0)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] +) { + metal::simdgroup_float8x8 a = {}; + metal::simdgroup_float8x8 b = {}; + metal::simdgroup_float8x8 c = {}; + metal::simdgroup_float8x8 d = {}; + c = NagaCooperativeLoad(&ext[4], 8u, false); + metal::simdgroup_float8x8 _e6 = a; + metal::simdgroup_float8x8 _e8 = b; + metal::simdgroup_float8x8 _e9 = c; + d = NagaCooperativeMultiplyAdd(_e6, _e8, _e9); + metal::simdgroup_float8x8 _e12 = d; + simdgroup_store(_e12, &ext[0], 8u); + metal::simdgroup_float8x8 _e16 = d; + c = _e16; + return; +} diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm new file mode 100644 index 00000000000..f2c9b5ceb53 --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -0,0 +1,99 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 46 +OpCapability Shader +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "GLSL.std.450" 
+OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %25 "main" %15 %18 %21 +OpExecutionMode %25 LocalSize 8 8 1 +%3 = OpString "cooperative-matrix.wgsl" +OpSource Unknown 0 %3 "var a: coop_mat8x8; +var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c = coopLoad>(&ext[4]); + var d = coopMultiplyAdd(a, b, c); + coopStore(d, &ext[0]); + c = d; +} +" +OpName %15 "a" +OpName %18 "b" +OpName %21 "ext" +OpName %25 "main" +OpName %29 "c" +OpName %32 "d" +OpDecorate %12 ArrayStride 4 +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 0 +OpDecorate %22 Block +OpMemberDecorate %22 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeFloat 32 +%7 = OpTypeInt 32 0 +%6 = OpConstant %7 3 +%8 = OpConstant %7 8 +%9 = OpConstant %7 0 +%5 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %9 +%11 = OpConstant %7 1 +%10 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %11 +%12 = OpTypeRuntimeArray %4 +%14 = OpConstant %7 2 +%13 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %14 +%16 = OpTypePointer Private %5 +%17 = OpConstantNull %5 +%15 = OpVariable %16 Private %17 +%19 = OpTypePointer Private %10 +%20 = OpConstantNull %10 +%18 = OpVariable %19 Private %20 +%22 = OpTypeStruct %12 +%23 = OpTypePointer StorageBuffer %22 +%21 = OpVariable %23 StorageBuffer +%26 = OpTypeFunction %2 +%27 = OpTypePointer StorageBuffer %12 +%30 = OpTypePointer Function %13 +%31 = OpConstantNull %13 +%33 = OpConstantNull %13 +%35 = OpTypePointer StorageBuffer %4 +%36 = OpConstant %7 4 +%25 = OpFunction %2 None %26 +%24 = OpLabel +%29 = OpVariable %30 Function %31 +%32 = OpVariable %30 Function %33 +%28 = OpAccessChain %27 %21 %9 +OpBranch %34 +%34 = OpLabel +OpLine %3 8 44 +OpLine %3 8 13 +%37 = OpAccessChain %35 %28 %36 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 +OpLine %3 8 5 +OpStore %29 %38 +OpLine %3 9 29 +%39 = OpLoad %5 %15 +OpLine %3 9 13 +%40 = OpLoad %10 %18 +%41 = OpLoad %13 %29 +%42 = OpCooperativeMatrixMulAddKHR %13 %39 %40 %41 +OpLine %3 9 5 +OpStore 
%32 %42 +OpLine %3 1 1 +%43 = OpLoad %13 %32 +OpLine %3 10 19 +OpLine %3 10 5 +%44 = OpAccessChain %35 %28 %9 +OpCooperativeMatrixStoreKHR %44 %43 %11 %8 +OpLine %3 1 1 +%45 = OpLoad %13 %32 +OpLine %3 11 5 +OpStore %29 %45 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl new file mode 100644 index 00000000000..183dc84ad74 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -0,0 +1,21 @@ +var a: coop_mat8x8; +var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat8x8; + var d: coop_mat8x8; + + c = coopLoad>((&ext[4]), 8u); + let _e6 = a; + let _e8 = b; + let _e9 = c; + d = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d; + coopStore(_e12, (&ext[0]), 8u); + let _e16 = d; + c = _e16; + return; +}