Skip to content

Commit

Permalink
[wgsl-in/spv-out] Add support for WGSL's atomicCompareExchangeWeak (#…
Browse files Browse the repository at this point in the history
…2165)

* Add support for WGSL's `atomicCompareExchangeWeak` with the `__atomic_compare_exchange_result` struct, and add SPIR-V codegen for it.

Partially addresses gpuweb/gpuweb#2113, #1755.

* Add tests for `atomicCompareExchangeWeak`, and support both u32 and i32 atomics with it.

* More thorough typechecking of the struct returned by `atomicCompareExchangeWeak`.
  • Loading branch information
aweinstock314 authored Dec 13, 2022
1 parent 8f1d82f commit 5d8fc3f
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 44 deletions.
46 changes: 44 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2078,8 +2078,50 @@ impl<'w> BlockContext<'w> {
value_id,
)
}
crate::AtomicFunction::Exchange { compare: Some(_) } => {
return Err(Error::FeatureNotImplemented("atomic CompareExchange"));
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
let scalar_type_id = match *value_inner {
crate::TypeInner::Scalar { kind, width } => {
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind,
width,
pointer_space: None,
}))
}
_ => unimplemented!(),
};
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
pointer_space: None,
}));

let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
cas_instr.set_type(scalar_type_id);
cas_instr.set_result(cas_result_id);
cas_instr.add_operand(pointer_id);
cas_instr.add_operand(scope_constant_id);
cas_instr.add_operand(semantics_id); // semantics if equal
cas_instr.add_operand(semantics_id); // semantics if not equal
cas_instr.add_operand(value_id);
cas_instr.add_operand(self.cached[cmp]);
block.body.push(cas_instr);
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
equality_result_id,
cas_result_id,
self.cached[cmp],
));
Instruction::composite_construct(
result_type_id,
id,
&[cas_result_id, equality_result_id],
)
}
};

Expand Down
52 changes: 48 additions & 4 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1630,8 +1630,13 @@ impl Parser {

let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult {
kind,
width,
ty: ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
),
comparison: false,
},
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
Expand Down Expand Up @@ -1861,9 +1866,48 @@ impl Parser {

let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => {
let bool_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
NagaSpan::UNDEFINED,
);
let scalar_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
);
let struct_ty = ctx.types.insert(
crate::Type {
name: Some("__atomic_compare_exchange_result".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("old_value".to_string()),
ty: scalar_ty,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("exchanged".to_string()),
ty: bool_ty,
binding: None,
offset: 4,
},
],
span: 8,
},
},
NagaSpan::UNDEFINED,
);
crate::Expression::AtomicResult {
kind,
width,
ty: struct_ty,
comparison: true,
}
}
Expand Down
6 changes: 1 addition & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1401,11 +1401,7 @@ pub enum Expression {
/// Result of calling another function.
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult {
kind: ScalarKind,
width: Bytes,
comparison: bool,
},
AtomicResult { ty: Handle<Type>, comparison: bool },
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
///
Expand Down
16 changes: 1 addition & 15 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,21 +644,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult {
kind,
width,
comparison,
} => {
if comparison {
TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Bi,
kind,
width,
})
} else {
TypeResolution::Value(Ti::Scalar { kind, width })
}
}
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
Expand Down
35 changes: 23 additions & 12 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#[cfg(feature = "validate")]
use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags};
use super::{
compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages,
TypeFlags,
};
#[cfg(feature = "validate")]
use crate::arena::UniqueArena;

Expand Down Expand Up @@ -115,8 +118,8 @@ pub enum ExpressionError {
WrongArgumentCount(crate::MathFunction),
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
#[error("Atomic result type can't be {0:?} of {1} bytes")]
InvalidAtomicResultType(crate::ScalarKind, crate::Bytes),
#[error("Atomic result type can't be {0:?}")]
InvalidAtomicResultType(Handle<crate::Type>),
#[error("Shader requires capability {0:?}")]
MissingCapabilities(super::Capabilities),
}
Expand Down Expand Up @@ -1389,19 +1392,27 @@ impl super::Validator {
ShaderStages::all()
}
E::CallResult(function) => other_infos[function.index()].available_stages,
E::AtomicResult {
kind,
width,
comparison: _,
} => {
let good = match kind {
crate::ScalarKind::Uint | crate::ScalarKind::Sint => {
self.check_width(kind, width)
E::AtomicResult { ty, comparison } => {
let scalar_predicate = |ty: &crate::TypeInner| match ty {
&crate::TypeInner::Scalar {
kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint),
width,
} => self.check_width(kind, width),
_ => false,
};
let good = match &module.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
&module.types,
members,
scalar_predicate,
)
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidAtomicResultType(kind, width));
return Err(ExpressionError::InvalidAtomicResultType(ty));
}
ShaderStages::all()
}
Expand Down
29 changes: 23 additions & 6 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
use crate::arena::{Arena, UniqueArena};
use crate::arena::{BadHandle, Handle};

#[cfg(feature = "validate")]
use super::validate_atomic_compare_exchange_struct;

use super::{
analyzer::{UniformityDisruptor, UniformityRequirements},
ExpressionError, FunctionInfo, ModuleInfo,
Expand Down Expand Up @@ -363,12 +366,26 @@ impl super::Validator {
.into_other());
}
match context.expressions[result] {
//TODO: support atomic result with comparison
crate::Expression::AtomicResult {
kind,
width,
comparison: false,
} if kind == ptr_kind && width == ptr_width => {}
crate::Expression::AtomicResult { ty, comparison }
if {
let scalar_predicate = |ty: &crate::TypeInner| {
*ty == crate::TypeInner::Scalar {
kind: ptr_kind,
width: ptr_width,
}
};
match &context.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
context.types,
members,
scalar_predicate,
)
}
_ => false,
}
} => {}
_ => {
return Err(AtomicError::ResultTypeMismatch(result)
.with_span_handle(result, context.expressions)
Expand Down
17 changes: 17 additions & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,20 @@ impl Validator {
Ok(mod_info)
}
}

#[cfg(feature = "validate")]
fn validate_atomic_compare_exchange_struct(
types: &UniqueArena<crate::Type>,
members: &[crate::StructMember],
scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
members.len() == 2
&& members[0].name.as_deref() == Some("old_value")
&& scalar_predicate(&types[members[0].ty].inner)
&& members[1].name.as_deref() == Some("exchanged")
&& types[members[1].ty].inner
== crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
}
}
34 changes: 34 additions & 0 deletions tests/in/atomicCompareExchange.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
let SIZE: u32 = 128u;

@group(0) @binding(0)
var<storage,read_write> arr_i32: array<atomic<i32>, SIZE>;
@group(0) @binding(1)
var<storage,read_write> arr_u32: array<atomic<u32>, SIZE>;

@compute @workgroup_size(1)
fn test_atomic_compare_exchange_i32() {
for(var i = 0u; i < SIZE; i++) {
var old = atomicLoad(&arr_i32[i]);
var exchanged = false;
while(!exchanged) {
let new_ = bitcast<i32>(bitcast<f32>(old) + 1.0);
let result = atomicCompareExchangeWeak(&arr_i32[i], old, new_);
old = result.old_value;
exchanged = result.exchanged;
}
}
}

@compute @workgroup_size(1)
fn test_atomic_compare_exchange_u32() {
for(var i = 0u; i < SIZE; i++) {
var old = atomicLoad(&arr_u32[i]);
var exchanged = false;
while(!exchanged) {
let new_ = bitcast<u32>(bitcast<f32>(old) + 1.0);
let result = atomicCompareExchangeWeak(&arr_u32[i], old, new_);
old = result.old_value;
exchanged = result.exchanged;
}
}
}
Loading

0 comments on commit 5d8fc3f

Please sign in to comment.