| 
 | 1 | +// Make sure that addition with carry produces expected results  | 
 | 2 | +// with and without expansion to primitive add/cmp ops for WebGPU.  | 
 | 3 | + | 
 | 4 | +// RUN: mlir-vulkan-runner %s \  | 
 | 5 | +// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \  | 
 | 6 | +// RUN:  --entry-point-result=void | FileCheck %s  | 
 | 7 | + | 
 | 8 | +// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \  | 
 | 9 | +// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \  | 
 | 10 | +// RUN:  --entry-point-result=void | FileCheck %s  | 
 | 11 | + | 
 | 12 | +// CHECK: [0, 1, 0, 42]  | 
 | 13 | +// CHECK: [0, 0, 1, 1]  | 
 | 14 | +module attributes {  | 
 | 15 | +  gpu.container_module,  | 
 | 16 | +  spirv.target_env = #spirv.target_env<  | 
 | 17 | +    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>  | 
 | 18 | +} {  | 
 | 19 | +  gpu.module @kernels {  | 
 | 20 | +    // Loads one element from %arg0 and %arg1 at the block index, performs an  | 
 | 21 | +    // unsigned add with carry-out, and stores the sum to %arg2 and the carry  | 
 | 22 | +    // (zero-extended to i8) to %arg3 at the same index.  | 
 | 23 | +    gpu.func @kernel_add(%arg0 : memref<4xi32>, %arg1 : memref<4xi32>, %arg2 : memref<4xi32>, %arg3 : memref<4xi8>)  | 
 | 24 | +      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {  | 
 | 25 | +      %0 = gpu.block_id x  | 
 | 26 | +      %lhs = memref.load %arg0[%0] : memref<4xi32>  | 
 | 27 | +      %rhs = memref.load %arg1[%0] : memref<4xi32>  | 
 | 28 | +      %sum, %carry = arith.addui_extended %lhs, %rhs : i32  | 
 | 29 | + | 
 | 30 | +      // We will convert to i8 as this is the smallest value we can use for  | 
 | 31 | +      // fill/print in the runner.  | 
 | 32 | +      %carry_i8 = arith.extui %carry : i1 to i8  | 
 | 33 | + | 
 | 34 | +      memref.store %sum, %arg2[%0] : memref<4xi32>  | 
 | 35 | +      memref.store %carry_i8, %arg3[%0] : memref<4xi8>  | 
 | 36 | +      gpu.return  | 
 | 37 | +    }  | 
 | 38 | +  }  | 
 | 39 | + | 
 | 40 | +  // Host driver: zero-fills the output buffers, writes the input vectors,  | 
 | 41 | +  // launches @kernel_add over 4 blocks, then prints the sums followed by the  | 
 | 42 | +  // carries (the two lines matched by the CHECK directives).  | 
 | 43 | +  func.func @main() {  | 
 | 44 | +    %buf0 = memref.alloc() : memref<4xi32>  | 
 | 45 | +    %buf1 = memref.alloc() : memref<4xi32>  | 
 | 46 | +    %buf2 = memref.alloc() : memref<4xi32>  | 
 | 47 | +    %buf3 = memref.alloc() : memref<4xi8>  | 
 | 48 | +    %i32_0 = arith.constant 0 : i32  | 
 | 49 | +    %i8_0 = arith.constant 0 : i8  | 
 | 50 | + | 
 | 51 | +    // Initialize output buffers.  | 
 | 52 | +    %buf4 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>  | 
 | 53 | +    %buf5 = memref.cast %buf3 : memref<4xi8> to memref<?xi8>  | 
 | 54 | +    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()  | 
 | 55 | +    call @fillResource1DInt8(%buf5, %i8_0) : (memref<?xi8>, i8) -> ()  | 
 | 56 | + | 
 | 57 | +    %idx_0 = arith.constant 0 : index  | 
 | 58 | +    %idx_1 = arith.constant 1 : index  | 
 | 59 | +    %idx_4 = arith.constant 4 : index  | 
 | 60 | + | 
 | 61 | +    // Initialize input buffers. -1 is 0xFFFFFFFF when read as unsigned, so the  | 
 | 62 | +    // last two lanes wrap around and set the carry: the expected sums are  | 
 | 63 | +    // [0, 1, 0, 42] and the expected carries are [0, 0, 1, 1].  | 
 | 64 | +    %lhs_vals = arith.constant dense<[0, 0, -1, 43]> : vector<4xi32>  | 
 | 65 | +    %rhs_vals = arith.constant dense<[0, 1, 1, -1]> : vector<4xi32>  | 
 | 66 | +    vector.store %lhs_vals, %buf0[%idx_0] : memref<4xi32>, vector<4xi32>  | 
 | 67 | +    vector.store %rhs_vals, %buf1[%idx_0] : memref<4xi32>, vector<4xi32>  | 
 | 68 | + | 
 | 69 | +    gpu.launch_func @kernels::@kernel_add  | 
 | 70 | +        blocks in (%idx_4, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)  | 
 | 71 | +        args(%buf0 : memref<4xi32>, %buf1 : memref<4xi32>, %buf2 : memref<4xi32>, %buf3 : memref<4xi8>)  | 
 | 72 | +    %buf_sum = memref.cast %buf4 : memref<?xi32> to memref<*xi32>  | 
 | 73 | +    %buf_carry = memref.cast %buf5 : memref<?xi8> to memref<*xi8>  | 
 | 74 | +    call @printMemrefI32(%buf_sum) : (memref<*xi32>) -> ()  | 
 | 75 | +    call @printMemrefI8(%buf_carry) : (memref<*xi8>) -> ()  | 
 | 76 | +    return  | 
 | 77 | +  }  | 
 | 78 | +  // External fill/print helpers; presumably resolved from the libraries passed  | 
 | 79 | +  // via --shared-libs on the RUN lines (TODO confirm).  | 
 | 80 | +  func.func private @fillResource1DInt8(%0 : memref<?xi8>, %1 : i8)  | 
 | 81 | +  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)  | 
 | 82 | +  func.func private @printMemrefI8(%ptr : memref<*xi8>)  | 
 | 83 | +  func.func private @printMemrefI32(%ptr : memref<*xi32>)  | 
 | 84 | +}  | 
0 commit comments