Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add 32-bit floating-point atomics (SHADER_FLOAT32_ATOMIC) #6234

Open
wants to merge 17 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]

- Return submission index in `map_async` and `on_submitted_work_done` to track down completion of async callbacks. By @eliemichel in [#6360](https://github.com/gfx-rs/wgpu/pull/6360).

#### Vulkan

- Allow using some 32-bit floating-point atomic operations (load, store, add, sub, exchange) in shaders. It requires the extension `VK_EXT_shader_atomic_float`. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234).

#### Metal

- Allow using some 32-bit floating-point atomic operations (load, store, add, sub, exchange) in shaders. It requires Metal 3.0+ on a GPU in the Apple7, Apple8, Apple9, or Mac2 family. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234).

### Changes

#### Naga
Expand Down
192 changes: 120 additions & 72 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2733,62 +2733,115 @@ impl<'w> BlockContext<'w> {
let value_id = self.cached[value];
let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);

let crate::TypeInner::Scalar(scalar) = *value_inner else {
return Err(Error::FeatureNotImplemented(
"Atomics with non-scalar values",
));
};

let instruction = match *fun {
crate::AtomicFunction::Add => Instruction::atomic_binary(
spirv::Op::AtomicIAdd,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::Subtract => Instruction::atomic_binary(
spirv::Op::AtomicISub,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::And => Instruction::atomic_binary(
spirv::Op::AtomicAnd,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
spirv::Op::AtomicOr,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
spirv::Op::AtomicXor,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
),
crate::AtomicFunction::Add => {
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
spirv::Op::AtomicIAdd
}
crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::Subtract => {
let (spirv_op, value_id) = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
(spirv::Op::AtomicISub, value_id)
}
crate::ScalarKind::Float => {
// HACK: SPIR-V doesn't have an atomic floating-point subtraction,
// so we atomically add the negated value instead.
let neg_result_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::FNegate,
result_type_id,
neg_result_id,
value_id,
));
(spirv::Op::AtomicFAddEXT, neg_result_id)
}
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::And => {
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
spirv::Op::AtomicAnd
}
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::InclusiveOr => {
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
spirv::Op::AtomicOr
}
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::ExclusiveOr => {
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
spirv::Op::AtomicXor
}
_ => unimplemented!(),
};
Instruction::atomic_binary(
spirv_op,
result_type_id,
id,
pointer_id,
scope_constant_id,
semantics_id,
value_id,
)
}
crate::AtomicFunction::Min => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
}) => spirv::Op::AtomicSMin,
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicUMin,
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
_ => unimplemented!(),
};
Instruction::atomic_binary(
Expand All @@ -2802,15 +2855,9 @@ impl<'w> BlockContext<'w> {
)
}
crate::AtomicFunction::Max => {
let spirv_op = match *value_inner {
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
}) => spirv::Op::AtomicSMax,
crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Uint,
width: _,
}) => spirv::Op::AtomicUMax,
let spirv_op = match scalar.kind {
crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
_ => unimplemented!(),
};
Instruction::atomic_binary(
Expand All @@ -2835,20 +2882,21 @@ impl<'w> BlockContext<'w> {
)
}
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
let scalar_type_id = match *value_inner {
crate::TypeInner::Scalar(scalar) => {
self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(scalar),
)))
}
_ => unimplemented!(),
};
let scalar_type_id = self.get_type_id(LookupType::Local(
LocalType::Numeric(NumericType::Scalar(scalar)),
));
let bool_type_id = self.get_type_id(LookupType::Local(
LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL)),
));

let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();
let equality_operator = match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
spirv::Op::IEqual
}
_ => unimplemented!(),
};
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
cas_instr.set_type(scalar_type_id);
cas_instr.set_result(cas_result_id);
Expand All @@ -2860,7 +2908,7 @@ impl<'w> BlockContext<'w> {
cas_instr.add_operand(self.cached[cmp]);
block.body.push(cas_instr);
block.body.push(Instruction::binary(
spirv::Op::IEqual,
equality_operator,
bool_type_id,
equality_result_id,
cas_result_id,
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,16 @@ impl Writer {
crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
}
crate::TypeInner::Atomic(crate::Scalar {
width: 4,
kind: crate::ScalarKind::Float,
}) => {
self.require_any(
"32 bit floating-point atomics",
&[spirv::Capability::AtomicFloat32AddEXT],
)?;
self.use_extension("SPV_EXT_shader_atomic_float_add");
}
_ => {}
}
Ok(())
Expand Down
10 changes: 7 additions & 3 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Int64,
spirv::Capability::Int64Atomics,
spirv::Capability::Float16,
spirv::Capability::AtomicFloat32AddEXT,
spirv::Capability::Float64,
spirv::Capability::Geometry,
spirv::Capability::MultiView,
Expand All @@ -77,6 +78,7 @@ pub const SUPPORTED_EXTENSIONS: &[&str] = &[
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_vulkan_memory_model",
"SPV_KHR_multiview",
"SPV_EXT_shader_atomic_float_add",
];
/// Names of the extended instruction sets listed as supported; currently
/// only `GLSL.std.450`.
// NOTE(review): presumably consulted when a module imports an extended
// instruction set — confirm against the parser code outside this hunk.
pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];

Expand Down Expand Up @@ -4230,7 +4232,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| Op::AtomicUMax
| Op::AtomicAnd
| Op::AtomicOr
| Op::AtomicXor => self.parse_atomic_expr_with_value(
| Op::AtomicXor
| Op::AtomicFAddEXT => self.parse_atomic_expr_with_value(
inst,
&mut emitter,
ctx,
Expand All @@ -4239,15 +4242,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
body_idx,
match inst.op {
Op::AtomicExchange => crate::AtomicFunction::Exchange { compare: None },
Op::AtomicIAdd => crate::AtomicFunction::Add,
Op::AtomicIAdd | Op::AtomicFAddEXT => crate::AtomicFunction::Add,
Op::AtomicISub => crate::AtomicFunction::Subtract,
Op::AtomicSMin => crate::AtomicFunction::Min,
Op::AtomicUMin => crate::AtomicFunction::Min,
Op::AtomicSMax => crate::AtomicFunction::Max,
Op::AtomicUMax => crate::AtomicFunction::Max,
Op::AtomicAnd => crate::AtomicFunction::And,
Op::AtomicOr => crate::AtomicFunction::InclusiveOr,
_ => crate::AtomicFunction::ExclusiveOr,
Op::AtomicXor => crate::AtomicFunction::ExclusiveOr,
_ => unreachable!(),
},
)?,

Expand Down
17 changes: 16 additions & 1 deletion naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1937,14 +1937,18 @@ pub enum Statement {
/// If [`SHADER_INT64_ATOMIC_MIN_MAX`] or [`SHADER_INT64_ATOMIC_ALL_OPS`] are
/// enabled, this may also be [`I64`] or [`U64`].
///
/// If [`SHADER_FLOAT32_ATOMIC`] is enabled, this may be [`F32`].
///
/// [`Pointer`]: TypeInner::Pointer
/// [`Atomic`]: TypeInner::Atomic
/// [`I32`]: Scalar::I32
/// [`U32`]: Scalar::U32
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC
/// [`I64`]: Scalar::I64
/// [`U64`]: Scalar::U64
/// [`F32`]: Scalar::F32
pointer: Handle<Expression>,

/// Function to run on the atomic value.
Expand All @@ -1955,14 +1959,24 @@ pub enum Statement {
/// value here.
///
/// - The [`SHADER_INT64_ATOMIC_MIN_MAX`] capability allows
/// [`AtomicFunction::Min`] and [`AtomicFunction::Max`] here.
/// [`AtomicFunction::Min`] and [`AtomicFunction::Max`]
/// in the [`Storage`] address space here.
///
/// - If neither of those capabilities is present, then 64-bit scalar
/// atomics are not allowed.
///
/// If [`pointer`] refers to a 32-bit floating-point atomic value, then:
///
/// - The [`SHADER_FLOAT32_ATOMIC`] capability allows [`AtomicFunction::Add`],
/// [`AtomicFunction::Subtract`], and [`AtomicFunction::Exchange { compare: None }`]
/// in the [`Storage`] address space here.
///
/// [`AtomicFunction::Exchange { compare: None }`]: AtomicFunction::Exchange
/// [`pointer`]: Statement::Atomic::pointer
/// [`Storage`]: AddressSpace::Storage
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC
fun: AtomicFunction,

/// Value to use in the function.
Expand All @@ -1989,6 +2003,7 @@ pub enum Statement {
/// [`Exchange { compare: None }`]: AtomicFunction::Exchange
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC
result: Option<Handle<Expression>>,
},
/// Load uniformly from a uniform pointer in the workgroup address space.
Expand Down
Loading