Skip to content

Commit

Permalink
Add tests for atomicCompareExchangeWeak, and support both u32 and i…
Browse files Browse the repository at this point in the history
…32 atomics with it.
  • Loading branch information
aweinstock314 committed Dec 9, 2022
1 parent ee026e8 commit 8ce2ed7
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2080,14 +2080,17 @@ impl<'w> BlockContext<'w> {
)
}
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
// TODO: look this up from the atomic expression's scalar type so that it works with i32 as well
let scalar_u32 =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Uint,
width: 4,
pointer_space: None,
}));
let scalar_type_id = match *value_inner {
crate::TypeInner::Scalar { kind, width } => {
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind,
width,
pointer_space: None,
}))
}
_ => unimplemented!(),
};
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
Expand All @@ -2099,7 +2102,7 @@ impl<'w> BlockContext<'w> {
let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
cas_instr.set_type(scalar_u32);
cas_instr.set_type(scalar_type_id);
cas_instr.set_result(cas_result_id);
cas_instr.add_operand(pointer_id);
cas_instr.add_operand(scope_constant_id);
Expand Down
34 changes: 34 additions & 0 deletions tests/in/atomicCompareExchange.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
let SIZE: u32 = 128u;

@group(0) @binding(0)
var<storage,read_write> arr_i32: array<atomic<i32>, SIZE>;
@group(0) @binding(1)
var<storage,read_write> arr_u32: array<atomic<u32>, SIZE>;

@compute @workgroup_size(1)
fn test_atomic_compare_exchange_i32() {
for(var i = 0u; i < SIZE; i++) {
var old = atomicLoad(&arr_i32[i]);
var exchanged = false;
while(!exchanged) {
let new_ = bitcast<i32>(bitcast<f32>(old) + 1.0);
let result = atomicCompareExchangeWeak(&arr_i32[i], old, new_);
old = result.old_value;
exchanged = result.exchanged;
}
}
}

@compute @workgroup_size(1)
fn test_atomic_compare_exchange_u32() {
for(var i = 0u; i < SIZE; i++) {
var old = atomicLoad(&arr_u32[i]);
var exchanged = false;
while(!exchanged) {
let new_ = bitcast<u32>(bitcast<f32>(old) + 1.0);
let result = atomicCompareExchangeWeak(&arr_u32[i], old, new_);
old = result.old_value;
exchanged = result.exchanged;
}
}
}
188 changes: 188 additions & 0 deletions tests/out/spv/atomicCompareExchange.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 116
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %31 "test_atomic_compare_exchange_i32"
OpEntryPoint GLCompute %79 "test_atomic_compare_exchange_u32"
OpExecutionMode %31 LocalSize 1 1 1
OpExecutionMode %79 LocalSize 1 1 1
OpDecorate %12 ArrayStride 4
OpDecorate %13 ArrayStride 4
OpMemberDecorate %14 0 Offset 0
OpMemberDecorate %14 1 Offset 4
OpMemberDecorate %15 0 Offset 0
OpMemberDecorate %15 1 Offset 4
OpDecorate %16 DescriptorSet 0
OpDecorate %16 Binding 0
OpDecorate %17 Block
OpMemberDecorate %17 0 Offset 0
OpDecorate %19 DescriptorSet 0
OpDecorate %19 Binding 1
OpDecorate %20 Block
OpMemberDecorate %20 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 0
%3 = OpConstant %4 128
%5 = OpConstant %4 0
%6 = OpConstant %4 1
%8 = OpTypeBool
%7 = OpConstantFalse %8
%10 = OpTypeFloat 32
%9 = OpConstant %10 1.0
%11 = OpTypeInt 32 1
%12 = OpTypeArray %11 %3
%13 = OpTypeArray %4 %3
%14 = OpTypeStruct %11 %8
%15 = OpTypeStruct %4 %8
%17 = OpTypeStruct %12
%18 = OpTypePointer StorageBuffer %17
%16 = OpVariable %18 StorageBuffer
%20 = OpTypeStruct %13
%21 = OpTypePointer StorageBuffer %20
%19 = OpVariable %21 StorageBuffer
%23 = OpTypePointer Function %4
%25 = OpTypePointer Function %11
%26 = OpConstantNull %11
%28 = OpTypePointer Function %8
%29 = OpConstantNull %8
%32 = OpTypeFunction %2
%33 = OpTypePointer StorageBuffer %12
%35 = OpTypePointer StorageBuffer %13
%46 = OpTypePointer StorageBuffer %11
%49 = OpConstant %11 1
%50 = OpConstant %4 64
%75 = OpConstantNull %4
%77 = OpConstantNull %8
%91 = OpTypePointer StorageBuffer %4
%31 = OpFunction %2 None %32
%30 = OpLabel
%22 = OpVariable %23 Function %5
%24 = OpVariable %25 Function %26
%27 = OpVariable %28 Function %29
%34 = OpAccessChain %33 %16 %5
OpBranch %36
%36 = OpLabel
OpBranch %37
%37 = OpLabel
OpLoopMerge %38 %40 None
OpBranch %39
%39 = OpLabel
%41 = OpLoad %4 %22
%42 = OpULessThan %8 %41 %3
OpSelectionMerge %43 None
OpBranchConditional %42 %43 %44
%44 = OpLabel
OpBranch %38
%43 = OpLabel
%45 = OpLoad %4 %22
%47 = OpAccessChain %46 %34 %45
%48 = OpAtomicLoad %11 %47 %49 %50
OpStore %24 %48
OpStore %27 %7
OpBranch %51
%51 = OpLabel
OpLoopMerge %52 %54 None
OpBranch %53
%53 = OpLabel
%55 = OpLoad %8 %27
%56 = OpLogicalNot %8 %55
OpSelectionMerge %57 None
OpBranchConditional %56 %57 %58
%58 = OpLabel
OpBranch %52
%57 = OpLabel
%59 = OpLoad %11 %24
%60 = OpBitcast %10 %59
%61 = OpFAdd %10 %60 %9
%62 = OpBitcast %11 %61
%63 = OpLoad %4 %22
%64 = OpLoad %11 %24
%66 = OpAccessChain %46 %34 %63
%67 = OpAtomicCompareExchange %11 %66 %49 %50 %50 %62 %64
%68 = OpIEqual %8 %67 %64
%65 = OpCompositeConstruct %14 %67 %68
%69 = OpCompositeExtract %11 %65 0
OpStore %24 %69
%70 = OpCompositeExtract %8 %65 1
OpStore %27 %70
OpBranch %54
%54 = OpLabel
OpBranch %51
%52 = OpLabel
OpBranch %40
%40 = OpLabel
%71 = OpLoad %4 %22
%72 = OpIAdd %4 %71 %6
OpStore %22 %72
OpBranch %37
%38 = OpLabel
OpReturn
OpFunctionEnd
%79 = OpFunction %2 None %32
%78 = OpLabel
%73 = OpVariable %23 Function %5
%74 = OpVariable %23 Function %75
%76 = OpVariable %28 Function %77
%80 = OpAccessChain %35 %19 %5
OpBranch %81
%81 = OpLabel
OpBranch %82
%82 = OpLabel
OpLoopMerge %83 %85 None
OpBranch %84
%84 = OpLabel
%86 = OpLoad %4 %73
%87 = OpULessThan %8 %86 %3
OpSelectionMerge %88 None
OpBranchConditional %87 %88 %89
%89 = OpLabel
OpBranch %83
%88 = OpLabel
%90 = OpLoad %4 %73
%92 = OpAccessChain %91 %80 %90
%93 = OpAtomicLoad %4 %92 %49 %50
OpStore %74 %93
OpStore %76 %7
OpBranch %94
%94 = OpLabel
OpLoopMerge %95 %97 None
OpBranch %96
%96 = OpLabel
%98 = OpLoad %8 %76
%99 = OpLogicalNot %8 %98
OpSelectionMerge %100 None
OpBranchConditional %99 %100 %101
%101 = OpLabel
OpBranch %95
%100 = OpLabel
%102 = OpLoad %4 %74
%103 = OpBitcast %10 %102
%104 = OpFAdd %10 %103 %9
%105 = OpBitcast %4 %104
%106 = OpLoad %4 %73
%107 = OpLoad %4 %74
%109 = OpAccessChain %91 %80 %106
%110 = OpAtomicCompareExchange %4 %109 %49 %50 %50 %105 %107
%111 = OpIEqual %8 %110 %107
%108 = OpCompositeConstruct %15 %110 %111
%112 = OpCompositeExtract %4 %108 0
OpStore %74 %112
%113 = OpCompositeExtract %8 %108 1
OpStore %76 %113
OpBranch %97
%97 = OpLabel
OpBranch %94
%95 = OpLabel
OpBranch %85
%85 = OpLabel
%114 = OpLoad %4 %73
%115 = OpIAdd %4 %114 %6
OpStore %73 %115
OpBranch %82
%83 = OpLabel
OpReturn
OpFunctionEnd
92 changes: 92 additions & 0 deletions tests/out/wgsl/atomicCompareExchange.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
struct gen___atomic_compare_exchange_result {
old_value: i32,
exchanged: bool,
}

struct gen___atomic_compare_exchange_result_1 {
old_value: u32,
exchanged: bool,
}

let SIZE: u32 = 128u;

@group(0) @binding(0)
var<storage, read_write> arr_i32_: array<atomic<i32>,SIZE>;
@group(0) @binding(1)
var<storage, read_write> arr_u32_: array<atomic<u32>,SIZE>;

@compute @workgroup_size(1, 1, 1)
fn test_atomic_compare_exchange_i32_() {
var i: u32 = 0u;
var old: i32;
var exchanged: bool;

loop {
let _e5 = i;
if (_e5 < SIZE) {
} else {
break;
}
let _e10 = i;
let _e12 = atomicLoad((&arr_i32_[_e10]));
old = _e12;
exchanged = false;
loop {
let _e16 = exchanged;
if !(_e16) {
} else {
break;
}
let _e18 = old;
let new_ = bitcast<i32>((bitcast<f32>(_e18) + 1.0));
let _e23 = i;
let _e25 = old;
let _e26 = atomicCompareExchangeWeak((&arr_i32_[_e23]), _e25, new_);
old = _e26.old_value;
exchanged = _e26.exchanged;
}
continuing {
let _e7 = i;
i = (_e7 + 1u);
}
}
return;
}

@compute @workgroup_size(1, 1, 1)
fn test_atomic_compare_exchange_u32_() {
var i_1: u32 = 0u;
var old_1: u32;
var exchanged_1: bool;

loop {
let _e5 = i_1;
if (_e5 < SIZE) {
} else {
break;
}
let _e10 = i_1;
let _e12 = atomicLoad((&arr_u32_[_e10]));
old_1 = _e12;
exchanged_1 = false;
loop {
let _e16 = exchanged_1;
if !(_e16) {
} else {
break;
}
let _e18 = old_1;
let new_1 = bitcast<u32>((bitcast<f32>(_e18) + 1.0));
let _e23 = i_1;
let _e25 = old_1;
let _e26 = atomicCompareExchangeWeak((&arr_u32_[_e23]), _e25, new_1);
old_1 = _e26.old_value;
exchanged_1 = _e26.exchanged;
}
continuing {
let _e7 = i_1;
i_1 = (_e7 + 1u);
}
}
return;
}
1 change: 1 addition & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ fn convert_wgsl() {
"access",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("atomicCompareExchange", Targets::SPIRV | Targets::WGSL),
(
"padding",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
Expand Down

0 comments on commit 8ce2ed7

Please sign in to comment.