Skip to content

Commit

Permalink
[naga msl-out] Implement atomicCompareExchangeWeak for MSL backend
Browse files Browse the repository at this point in the history
* Fix issue gfx-rs#5257
* See PR gfx-rs/naga#5675, gfx-rs/naga#6265
  • Loading branch information
AsherJingkongChen committed Oct 10, 2024
1 parent 3693da2 commit 64f1eac
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 52 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).

- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)

#### Metal

- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)

### Bug Fixes

- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
Expand Down Expand Up @@ -123,6 +127,9 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Document `wgpu_hal` bounds-checking promises, and adapt `wgpu_core`'s lazy initialization logic to the slightly weaker-than-expected guarantees. By @jimblandy in [#6201](https://github.com/gfx-rs/wgpu/pull/6201)
- Raise validation error instead of panicking in `{Render,Compute}Pipeline::get_bind_group_layout` on native / WebGL. By @bgr360 in [#6280](https://github.com/gfx-rs/wgpu/pull/6280).
- **BREAKING**: Remove the last exposed C symbols in project, located in `wgpu_core::render::bundle::bundle_ffi`, to allow multiple versions of WGPU to compile together. By @ErichDonGubler in [#6272](https://github.com/gfx-rs/wgpu/pull/6272).
- Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089).
- When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178).
- Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094).

#### GLES / OpenGL

Expand All @@ -143,6 +150,16 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- `wgpu_hal::gles::Adapter::new_external` now requires the context to be current when dropping the adapter and related objects. By @Imberflur in [#6114](https://github.com/gfx-rs/wgpu/pull/6114).
- Reduce the amount of debug and trace logs emitted by wgpu-core and wgpu-hal. By @nical in [#6065](https://github.com/gfx-rs/wgpu/issues/6065)
- `Rg11b10Float` is renamed to `Rg11b10Ufloat`. By @sagudev in [#6108](https://github.com/gfx-rs/wgpu/pull/6108)
- Invalidate the device when we encounter driver-induced device loss or on unexpected errors. By @teoxoy in [#6229](https://github.com/gfx-rs/wgpu/pull/6229).
- Make Vulkan error handling more robust. By @teoxoy in [#6119](https://github.com/gfx-rs/wgpu/pull/6119).

#### Internal

- Tracker simplifications. By @teoxoy in [#6073](https://github.com/gfx-rs/wgpu/pull/6073) & [#6088](https://github.com/gfx-rs/wgpu/pull/6088).
- D3D12 cleanup. By @teoxoy in [#6200](https://github.com/gfx-rs/wgpu/pull/6200).
- Use `ManuallyDrop` in remaining places. By @teoxoy in [#6092](https://github.com/gfx-rs/wgpu/pull/6092).
- Move out invalidity from the `Registry`. By @teoxoy in [#6243](https://github.com/gfx-rs/wgpu/pull/6243).
- Remove `backend` from ID. By @teoxoy in [#6263](https://github.com/gfx-rs/wgpu/pull/6263).

#### HAL

Expand Down
130 changes: 80 additions & 50 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

Expand Down Expand Up @@ -1279,42 +1280,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
Expand Down Expand Up @@ -3182,24 +3147,65 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
let context = &context.expression;

// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let fun_str = if let Some(result) = result {
let fun_key = if let Some(result) = result {
let res_name = Baked(result).to_string();
self.start_baking_expression(result, &context.expression, &res_name)?;
self.start_baking_expression(result, context, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl()
} else if context.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
fun.to_msl()
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

// Put the atomic function invocation.
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}
}

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

// Done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
Expand Down Expand Up @@ -3827,7 +3833,33 @@ impl<W: Write> Writer<W> {
}}"
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let arg_type_name = scalar.to_msl_name();
let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
"\
template <typename A>
{struct_name} {defined_func_name}(
{address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}
}
}
}

Expand Down Expand Up @@ -6065,8 +6097,8 @@ fn test_stack_size() {
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
const fn to_msl(self) -> &'static str {
match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
Expand All @@ -6075,10 +6107,8 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
}
}

fn to_msl_64_bit(self) -> Result<&'static str, Error> {
Expand Down
161 changes: 161 additions & 0 deletions naga/tests/out/msl/atomicCompareExchange.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_2 {
metal::atomic_int inner[128];
};
struct type_4 {
metal::atomic_uint inner[128];
};
struct _atomic_compare_exchange_resultSint4_ {
int old_value;
bool exchanged;
};
struct _atomic_compare_exchange_resultUint4_ {
uint old_value;
bool exchanged;
};

template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}

template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
constant uint SIZE = 128u;

kernel void test_atomic_compare_exchange_i32_(
device type_2& arr_i32_ [[user(fake0)]]
) {
uint i = 0u;
int old = {};
bool exchanged = {};
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
}
loop_init = false;
uint _e2 = i;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i;
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
LOOP_IS_REACHABLE while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
break;
}
{
int _e14 = old;
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
uint _e20 = i;
int _e22 = old;
_atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
old = _e23.old_value;
exchanged = _e23.exchanged;
}
}
}
}
return;
}


kernel void test_atomic_compare_exchange_u32_(
device type_4& arr_u32_ [[user(fake0)]]
) {
uint i_1 = 0u;
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
LOOP_IS_REACHABLE while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
}
loop_init_1 = false;
uint _e2 = i_1;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i_1;
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
LOOP_IS_REACHABLE while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
break;
}
{
uint _e14 = old_1;
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
uint _e20 = i_1;
uint _e22 = old_1;
_atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
}
}
}
return;
}
Loading

0 comments on commit 64f1eac

Please sign in to comment.