Skip to content

Commit

Permalink
Add support for WGSL's atomicCompareExchangeWeak with the `__atomic…
Browse files Browse the repository at this point in the history
…_compare_exchange_result` struct, and add SPIR-V codegen for it.

Partially addresses gpuweb/gpuweb#2113, #1755.
  • Loading branch information
aweinstock314 committed Dec 9, 2022
1 parent bf4e62b commit ee026e8
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 13 deletions.
43 changes: 41 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2079,8 +2079,47 @@ impl<'w> BlockContext<'w> {
value_id,
)
}
crate::AtomicFunction::Exchange { compare: Some(_) } => {
return Err(Error::FeatureNotImplemented("atomic CompareExchange"));
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
// TODO: look this up from the atomic expression's scalar type so that it works with i32 as well
let scalar_u32 =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Uint,
width: 4,
pointer_space: None,
}));
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
pointer_space: None,
}));

let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
cas_instr.set_type(scalar_u32);
cas_instr.set_result(cas_result_id);
cas_instr.add_operand(pointer_id);
cas_instr.add_operand(scope_constant_id);
cas_instr.add_operand(semantics_id); // semantics if equal
cas_instr.add_operand(semantics_id); // semantics if not equal
cas_instr.add_operand(value_id);
cas_instr.add_operand(self.cached[cmp]);
block.body.push(cas_instr);
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
equality_result_id,
cas_result_id,
self.cached[cmp],
));
Instruction::composite_construct(
result_type_id,
id,
&[cas_result_id, equality_result_id],
)
}
};

Expand Down
44 changes: 42 additions & 2 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,7 @@ impl Parser {
crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult {
kind,
width,
comparison: false,
comparison: None,
},
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
};
Expand Down Expand Up @@ -1857,10 +1857,50 @@ impl Parser {

let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => {
let bool_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
NagaSpan::UNDEFINED,
);
let scalar_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
);
let struct_ty = ctx.types.insert(
crate::Type {
name: Some("__atomic_compare_exchange_result".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("old_value".to_string()),
ty: scalar_ty,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("exchanged".to_string()),
ty: bool_ty,
binding: None,
offset: 4,
},
],
span: 8,
},
},
NagaSpan::UNDEFINED,
);
crate::Expression::AtomicResult {
kind,
width,
comparison: true,
comparison: Some(struct_ty),
}
}
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,7 @@ pub enum Expression {
AtomicResult {
kind: ScalarKind,
width: Bytes,
comparison: bool,
comparison: Option<Handle<Type>>,
},
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
Expand Down
8 changes: 2 additions & 6 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,12 +649,8 @@ impl<'a> ResolveContext<'a> {
width,
comparison,
} => {
if comparison {
TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Bi,
kind,
width,
})
if let Some(struct_ty) = comparison {
TypeResolution::Handle(struct_ty)
} else {
TypeResolution::Value(Ti::Scalar { kind, width })
}
Expand Down
10 changes: 8 additions & 2 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,17 @@ impl super::Validator {
.into_other());
}
match context.expressions[result] {
//TODO: support atomic result with comparison
//TODO: does the result of an atomicCompareExchange need additional validation, or does the existing validation for
// the struct type it returns suffice?
crate::Expression::AtomicResult {
kind,
width,
comparison: false,
comparison: Some(_),
} if kind == ptr_kind && width == ptr_width => {}
crate::Expression::AtomicResult {
kind,
width,
comparison: None,
} if kind == ptr_kind && width == ptr_width => {}
_ => {
return Err(AtomicError::ResultTypeMismatch(result)
Expand Down

0 comments on commit ee026e8

Please sign in to comment.