[naga] Implement atomicCompareExchangeWeak for MSL backend
* See issue gfx-rs#5257
AsherJingkongChen committed Oct 10, 2024
1 parent d9178a1 commit 76c7deb
Showing 6 changed files with 310 additions and 47 deletions.
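
In short, the backend now lowers WGSL's `atomicCompareExchangeWeak` by emitting, for each scalar type it is used with, a result struct plus a small template helper that wraps Metal's `atomic_compare_exchange_weak_explicit` and packages the observed old value together with the success flag. The sketch below is condensed from the generated test output later in this commit (only the `device` / `int` overload is shown) and is purely illustrative:

```msl
// Result struct emitted per scalar type: the observed old value plus a success flag.
struct _atomic_compare_exchange_resultSint4_ {
    int old_value;
    bool exchanged;
};

namespace metal {
    // Helper emitted by the backend; an equivalent overload is also generated
    // for the `threadgroup` address space.
    template <typename A>
    _atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
        volatile device A *atomic_ptr,
        int cmp,
        int v
    ) {
        // Metal's builtin writes the value it actually observed back into `cmp`.
        bool swapped = metal::atomic_compare_exchange_weak_explicit(
            atomic_ptr, &cmp, v,
            metal::memory_order_relaxed, metal::memory_order_relaxed
        );
        return _atomic_compare_exchange_resultSint4_{cmp, swapped};
    }
}
```
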
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).

- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)

#### Metal

- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6264](https://github.com/gfx-rs/wgpu/pull/6264)

### Bug Fixes

- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
@@ -136,6 +136,8 @@ pub enum Error {
UnsupportedAttribute(String),
#[error("function '{0}' is not supported for target MSL version")]
UnsupportedFunction(String),
#[error("scalar {0:?} is not supported for target MSL version")]
UnsupportedScalar(crate::Scalar),
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
130 changes: 85 additions & 45 deletions naga/src/back/msl/writer.rs
@@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION_KEY: &str = "naga_atomic_compare_exchange_weak";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

@@ -1279,42 +1280,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
@@ -3182,24 +3147,61 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
let context = &context.expression;

// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let fun_str = if let Some(result) = result {
let fun_key = if let Some(result) = result {
let res_name = Baked(result).to_string();
self.start_baking_expression(result, &context.expression, &res_name)?;
self.start_baking_expression(result, context, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
} else if context.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;

// Put the extra argument if provided.
if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
} else {
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

// Done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
@@ -3827,7 +3829,47 @@ impl<W: Write> Writer<W> {
}}"
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let crate::Scalar { kind, width } = scalar;
let arg_type_name = match width {
1 => "bool",
4 => match kind {
crate::ScalarKind::Sint => "int",
crate::ScalarKind::Uint => "uint",
crate::ScalarKind::Float => "float",
_ => return Err(Error::UnsupportedScalar(scalar)),
},
_ => return Err(Error::UnsupportedScalar(scalar)),
};

let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_key = ATOMIC_COMP_EXCH_FUNCTION_KEY;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(self.out, "namespace {NAMESPACE} {{")?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
" \
template <typename A>
{struct_name} atomic_{defined_func_key}_explicit(
volatile {address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}

writeln!(self.out, "}}")?;
}
}
}

@@ -6075,9 +6117,7 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION_KEY,
})
}

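The comments in the hunk above describe how bounds checking interacts with the atomic call: under the `ReadZeroSkipWrite` policy, the condition has to surround the entire atomic operation while the pointer operand itself is emitted unchecked. A hand-written illustration of the resulting MSL shape follows; the names `result`, `index`, `array_length`, `buf`, `cmp`, and `value` are hypothetical and not taken from the test below, while `DefaultConstructible` is the backend's existing fallback helper referenced in the code above:

```msl
// Illustrative only: the bounds condition wraps the whole atomic call in a
// ternary, and an out-of-bounds index yields a default-constructed result
// instead of performing the exchange.
_atomic_compare_exchange_resultSint4_ result = index < array_length
    ? metal::atomic_naga_atomic_compare_exchange_weak_explicit(&buf.inner[index], cmp, value)
    : DefaultConstructible();
```
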
164 changes: 164 additions & 0 deletions naga/tests/out/msl/atomicCompareExchange.msl
@@ -0,0 +1,164 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_2 {
metal::atomic_int inner[128];
};
struct type_4 {
metal::atomic_uint inner[128];
};
struct _atomic_compare_exchange_resultSint4_ {
int old_value;
bool exchanged;
};
struct _atomic_compare_exchange_resultUint4_ {
uint old_value;
bool exchanged;
};

namespace metal {
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
}

namespace metal {
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
}
constant uint SIZE = 128u;

kernel void test_atomic_compare_exchange_i32_(
device type_2& arr_i32_ [[user(fake0)]]
) {
uint i = 0u;
int old = {};
bool exchanged = {};
bool loop_init = true;
while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
}
loop_init = false;
uint _e2 = i;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i;
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
break;
}
{
int _e14 = old;
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
uint _e20 = i;
int _e22 = old;
_atomic_compare_exchange_resultSint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
old = _e23.old_value;
exchanged = _e23.exchanged;
}
}
}
}
return;
}


kernel void test_atomic_compare_exchange_u32_(
device type_4& arr_u32_ [[user(fake0)]]
) {
uint i_1 = 0u;
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
}
loop_init_1 = false;
uint _e2 = i_1;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i_1;
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
break;
}
{
uint _e14 = old_1;
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
uint _e20 = i_1;
uint _e22 = old_1;
_atomic_compare_exchange_resultUint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
}
}
}
return;
}