Skip to content

Commit

Permalink
Merge pull request #25 from arhik/main
Browse files Browse the repository at this point in the history
Tiled matrix multiplication example demonstrating shared memory api and usage
  • Loading branch information
arhik authored Mar 27, 2024
2 parents abc74dc + e933882 commit 176e3dd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/matmul_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using Infiltrator
using Test

function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
gIdx = localId.x
gIdy = localId.y
gIdx = globalId.x
gIdy = globalId.y
gId = xDims.x*gIdy + gIdx
out[gId] = 0.0
sum = 0.0
Expand Down
47 changes: 36 additions & 11 deletions examples/tiled_matmul_kernel.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Revise
using WGPUCompute

using Test
using StaticArrays

const Vec2{T} = SVector{2, T}
Expand All @@ -11,21 +11,42 @@ const Mat3{T} = SMatrix{3, 3, T, 9}
const Mat4{T} = SMatrix{4, 4, T, 16}
const Vec{N, T} = SVector{N, T}


x = WgpuArray{Float32, 2}(rand(16, 16));
y = WgpuArray{Float32, 2}(rand(16, 16));

function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
gIdx = localId.x
gIdy = localId.y
lIdx = localId.x
lIdy = localId.y
gIdx = globalId.x
gIdy = globalId.y

#set out matrix to zero
gId = xDims.x*gIdy + gIdx
out[gId] = 0.0

# set local variable = 0.0
sum = 0.0
for i in 0:xDims.y
xIdx = xDims.x*i + gIdx
yIdx = yDims.x*gIdy + i
sum = sum + x[xIdx]*y[yIdx]

for tileId in 0:numWorkgroups.y
# copy block from x to shared memory
xId = workgroupId.x*workgroupDims.x + localId.x
yId = tileId*workgroupDims.y + localId.y
sId = localId.y*workgroupDims.x + localId.x
shmem1[sId] = x[yId*xDims.x + xId]

# copy block from y to shared memory
xId = tileId*workgroupDims.x + localId.x
yId = workgroupId.y*workgroupDims.y + localId.y
shmem2[sId] = y[yId*yDims.x + xId]
synchronize()

# block sums for each tid
for i in 0:workgroupDims.y
sum = sum + shmem1[i*workgroupDims.x + localId.x]*shmem2[localId.y*workgroupDims.x + i]
end
synchronize()
end

out[gId] = sum
end

Expand All @@ -39,9 +60,9 @@ function tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
wgSize = (16, 16)
@wgpukernel(
launch=true,
workgroupSizes=outSize,
workgroupCount=(1, 1),
shmem=(:shmem1=>(Vec4{Float32}, (4, 4)), :shmem2=>(Float32, (4, 4))),
workgroupSizes=(4, 4),
workgroupCount=(4, 4),
shmem=(:shmem1=>(Float32, (4, 4)), :shmem2=>(Float32, (4, 4))),
tiled_matmul_kernel(x, y, out)
)
return out
Expand All @@ -50,3 +71,7 @@ end
Base.:*(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N} = tiled_matmul(x, y)

z = x*y

z_cpu = (x |> collect)*(y |> collect)

@test z_cpu (z |> collect)

0 comments on commit 176e3dd

Please sign in to comment.