Commit 80e57c2

Merge pull request #151 from PhilipFackler/oneapi-testing

Update oneAPI testing

PhilipFackler authored Nov 24, 2024
2 parents af91915 + 81c5570 commit 80e57c2
Showing 15 changed files with 177 additions and 119 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
@@ -1,4 +1,3 @@
margin = 80
style = "sciml"
- format_doctrings = true
separate_kwargs_with_semicolon = true
16 changes: 8 additions & 8 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -127,10 +127,10 @@ function _parallel_for_amdgpu_LMN((L, M, N), f, x...)
end

function _parallel_reduce_amdgpu(N, op, ret, f, x...)
- shared_mem = @ROCStaticLocalArray(Float64, 512)
+ shared_mem = @ROCStaticLocalArray(eltype(ret), 512)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
ti = workitemIdx().x
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
shared_mem[ti] = 0.0

if i <= N
@@ -179,10 +179,10 @@ function _parallel_reduce_amdgpu(N, op, ret, f, x...)
end

function reduce_kernel_amdgpu(N, op, red, ret)
- shared_mem = @ROCStaticLocalArray(Float64, 512)
+ shared_mem = @ROCStaticLocalArray(eltype(ret), 512)
i = workitemIdx().x
ii = i
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
if N > 512
while ii <= N
tmp = op(tmp, @inbounds red[ii])
@@ -233,15 +233,15 @@ function reduce_kernel_amdgpu(N, op, red, ret)
end

function _parallel_reduce_amdgpu_MN((M, N), op, ret, f, x...)
- shared_mem = @ROCStaticLocalArray(Float64, 256)
+ shared_mem = @ROCStaticLocalArray(eltype(ret), 256)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
ti = workitemIdx().x
tj = workitemIdx().y
bi = workgroupIdx().x
bj = workgroupIdx().y

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((ti - 1) * 16) + tj
shared_mem[sid] = tmp

@@ -285,13 +285,13 @@ function _parallel_reduce_amdgpu_MN((M, N), op, ret, f, x...)
end

function reduce_kernel_amdgpu_MN((M, N), op, red, ret)
- shared_mem = @ROCStaticLocalArray(Float64, 256)
+ shared_mem = @ROCStaticLocalArray(eltype(ret), 256)
i = workitemIdx().x
j = workitemIdx().y
ii = i
jj = j

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((i - 1) * 16) + j
shared_mem[sid] = tmp

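The common thread in this file: the reduction kernels no longer hard-code Float64 for workgroup-local memory and accumulators; both now derive from eltype(ret), so reducing Float32 data stays in Float32 end to end. A minimal CPU-side sketch of the pattern (illustrative only, not JACC code):

function accumulate_like(ret, values)
    # The accumulator takes the element type of the output buffer
    # instead of a hard-coded Float64.
    tmp::eltype(ret) = 0.0
    for v in values
        tmp += v
    end
    return tmp
end

accumulate_like(zeros(Float32, 1), Float32[1, 2, 3])  # -> 6.0f0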
24 changes: 12 additions & 12 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -75,9 +75,9 @@ function JACC.parallel_reduce(
blocks = ceil(Int, N / threads)
ret = fill!(CUDA.CuArray{typeof(init)}(undef, blocks), init)
rret = CUDA.CuArray([init])
- CUDA.@sync @cuda threads=threads blocks=blocks shmem=512 * sizeof(Float64) _parallel_reduce_cuda(
+ CUDA.@sync @cuda threads=threads blocks=blocks shmem=512 * sizeof(typeof(init)) _parallel_reduce_cuda(
N, op, ret, f, x...)
- CUDA.@sync @cuda threads=threads blocks=1 shmem=512 * sizeof(Float64) reduce_kernel_cuda(
+ CUDA.@sync @cuda threads=threads blocks=1 shmem=512 * sizeof(typeof(init)) reduce_kernel_cuda(
blocks, op, ret, rret)
return Core.Array(rret)[]
end
@@ -93,10 +93,10 @@ function JACC.parallel_reduce(
rret = CUDA.CuArray([init])
CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(Mblocks, Nblocks) shmem=16 *
16 *
- sizeof(Float64) _parallel_reduce_cuda_MN(
+ sizeof(typeof(init)) _parallel_reduce_cuda_MN(
(M, N), op, ret, f, x...)
CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(1, 1) shmem=16 * 16 *
- sizeof(Float64) reduce_kernel_cuda_MN(
+ sizeof(typeof(init)) reduce_kernel_cuda_MN(
(Mblocks, Nblocks), op, ret, rret)
return Core.Array(rret)[]
end
@@ -129,10 +129,10 @@ function _parallel_for_cuda_LMN((L, M, N), f, x...)
end

function _parallel_reduce_cuda(N, op, ret, f, x...)
- shared_mem = @cuDynamicSharedMem(Float64, 512)
+ shared_mem = @cuDynamicSharedMem(eltype(ret), 512)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
ti = threadIdx().x
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
shared_mem[ti] = 0.0

if i <= N
@@ -180,10 +180,10 @@ function _parallel_reduce_cuda(N, op, ret, f, x...)
end

function reduce_kernel_cuda(N, op, red, ret)
- shared_mem = @cuDynamicSharedMem(Float64, 512)
+ shared_mem = @cuDynamicSharedMem(eltype(ret), 512)
i = threadIdx().x
ii = i
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
if N > 512
while ii <= N
tmp = op(tmp, @inbounds red[ii])
@@ -234,15 +234,15 @@ function reduce_kernel_cuda(N, op, red, ret)
end

function _parallel_reduce_cuda_MN((M, N), op, ret, f, x...)
- shared_mem = @cuDynamicSharedMem(Float64, 16*16)
+ shared_mem = @cuDynamicSharedMem(eltype(ret), 16*16)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
ti = threadIdx().x
tj = threadIdx().y
bi = blockIdx().x
bj = blockIdx().y

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((ti - 1) * 16) + tj
shared_mem[sid] = tmp

@@ -286,13 +286,13 @@ function _parallel_reduce_cuda_MN((M, N), op, ret, f, x...)
end

function reduce_kernel_cuda_MN((M, N), op, red, ret)
- shared_mem = @cuDynamicSharedMem(Float64, 16*16)
+ shared_mem = @cuDynamicSharedMem(eltype(ret), 16*16)
i = threadIdx().x
j = threadIdx().y
ii = i
jj = j

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((i - 1) * 16) + j
shared_mem[sid] = tmp

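In the CUDA extension the retyping also reaches the launch configuration: the dynamic shared-memory request is sized from the init value's type instead of a fixed sizeof(Float64). A sketch of the sizing rule (the helper name is illustrative, not part of JACC's API):

# Bytes of dynamic shared memory for `slots` accumulator slots of init's type.
shmem_bytes(init, slots) = slots * sizeof(typeof(init))

shmem_bytes(0.0f0, 512)  # 2048 bytes for a Float32 reduction
shmem_bytes(0.0, 512)    # 4096 bytes for a Float64 reduction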
51 changes: 35 additions & 16 deletions ext/JACCONEAPI/JACCONEAPI.jl
@@ -1,6 +1,6 @@
module JACCONEAPI

- using JACC, oneAPI
+ using JACC, oneAPI, oneAPI.oneL0

# overloaded array functions
include("array.jl")
@@ -39,7 +39,7 @@ function JACC.parallel_for(
end

function JACC.parallel_for(
- ::oneAPIBackend, (L, M, N)::Tuple{I, I}, f::F, x...) where {
+ ::oneAPIBackend, (L, M, N)::Tuple{I, I, I}, f::F, x...) where {
I <: Integer, F <: Function}
maxPossibleItems = 16
Litems = min(M, maxPossibleItems)
@@ -58,8 +58,8 @@ function JACC.parallel_reduce(
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
- ret = oneAPI.zeros(Float32, groups)
- rret = oneAPI.zeros(Float32, 1)
+ ret = oneAPI.zeros(typeof(init), groups)
+ rret = oneAPI.zeros(typeof(init), 1)
oneAPI.@sync @oneapi items=items groups=groups _parallel_reduce_oneapi(
N, op, ret, f, x...)
oneAPI.@sync @oneapi items=items groups=1 reduce_kernel_oneapi(
@@ -74,8 +74,8 @@ function JACC.parallel_reduce(
Nitems = min(N, numItems)
Mgroups = ceil(Int, M / Mitems)
Ngroups = ceil(Int, N / Nitems)
- ret = oneAPI.zeros(Float32, (Mgroups, Ngroups))
- rret = oneAPI.zeros(Float32, 1)
+ ret = oneAPI.zeros(typeof(init), (Mgroups, Ngroups))
+ rret = oneAPI.zeros(typeof(init), 1)
oneAPI.@sync @oneapi items=(Mitems, Nitems) groups=(Mgroups, Ngroups) _parallel_reduce_oneapi_MN(
(M, N), op, ret, f, x...)
oneAPI.@sync @oneapi items=(Mitems, Nitems) groups=(1, 1) reduce_kernel_oneapi_MN(
@@ -111,12 +111,10 @@ function _parallel_for_oneapi_LMN((L, M, N), f, x...)
end

function _parallel_reduce_oneapi(N, op, ret, f, x...)
- #shared_mem = oneLocalArray(Float32, 256)
- shared_mem = oneLocalArray(Float64, 256)
+ shared_mem = oneLocalArray(eltype(ret), 256)
i = get_global_id(0)
ti = get_local_id(0)
- #tmp::Float32 = 0.0
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
shared_mem[ti] = 0.0
if i <= N
tmp = @inbounds f(i, x...)
@@ -160,10 +158,10 @@ function _parallel_reduce_oneapi(N, op, ret, f, x...)
end

function reduce_kernel_oneapi(N, op, red, ret)
- shared_mem = oneLocalArray(Float64, 256)
+ shared_mem = oneLocalArray(eltype(ret), 256)
i = get_global_id()
ii = i
- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
if N > 256
while ii <= N
tmp = op(tmp, @inbounds red[ii])
@@ -210,15 +208,15 @@ function reduce_kernel_oneapi(N, op, red, ret)
end

function _parallel_reduce_oneapi_MN((M, N), op, ret, f, x...)
- shared_mem = oneLocalArray(Float64, 16 * 16)
+ shared_mem = oneLocalArray(eltype(ret), 16 * 16)
i = get_global_id(0)
j = get_global_id(1)
ti = get_local_id(0)
tj = get_local_id(1)
bi = get_group_id(0)
bj = get_group_id(1)

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((ti - 1) * 16) + tj
shared_mem[sid] = tmp

@@ -262,13 +260,13 @@ function _parallel_reduce_oneapi_MN((M, N), op, ret, f, x...)
end

function reduce_kernel_oneapi_MN((M, N), op, red, ret)
- shared_mem = oneLocalArray(Float64, 16 * 16)
+ shared_mem = oneLocalArray(eltype(ret), 16 * 16)
i = get_local_id(0)
j = get_local_id(1)
ii = i
jj = j

- tmp::Float64 = 0.0
+ tmp::eltype(ret) = 0.0
sid = ((i - 1) * 16) + j
shared_mem[sid] = tmp

@@ -408,4 +406,25 @@ end

JACC.array_type(::oneAPIBackend) = oneAPI.oneArray{T, N} where {T, N}

+ DefaultFloat = Union{Type, Nothing}
+
+ function _get_default_float()
+     if oneL0.module_properties(device()).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64
+         return Float64
+     else
+         @info """Float64 unsupported on the current device.
+         Default float for JACC.jl changed to Float32.
+         """
+         return Float32
+     end
+ end
+
+ function JACC.default_float(::oneAPIBackend)
+     global DefaultFloat
+     if isa(nothing, DefaultFloat)
+         DefaultFloat = _get_default_float()
+     end
+     return DefaultFloat
+ end
+
end # module JACCONEAPI
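The new default-float machinery queries Level Zero once for FP64 support and caches the answer in the DefaultFloat global, so repeat calls skip the device query. Assuming the oneAPI backend is active, usage would look like this (hypothetical session; the returned type depends on the device):

T = JACC.default_float()  # Float64 on FP64-capable devices, otherwise Float32
a = JACC.ones(T, 1_000)   # device array with the matching element type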
4 changes: 2 additions & 2 deletions ext/JACCONEAPI/array.jl
@@ -1,8 +1,8 @@

- function JACC.zeros(::oneAPIBackend, T, dims...)
+ function JACC.zeros(::oneAPIBackend, ::Type{T}, dims...) where {T}
return oneAPI.zeros(T, dims...)
end

- function JACC.ones(::oneAPIBackend, T, dims...)
+ function JACC.ones(::oneAPIBackend, ::Type{T}, dims...) where {T}
return oneAPI.ones(T, dims...)
end
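Tightening the element-type argument from an untyped T to ::Type{T} ... where {T} makes these methods dispatch on the type object itself, so a non-type argument fails at method selection instead of inside oneAPI.zeros. A standalone sketch of the difference (plain Julia, not JACC code):

loose(T, dims...) = Base.zeros(T, dims...)                     # accepts anything as T
strict(::Type{T}, dims...) where {T} = Base.zeros(T, dims...)  # only type objects match

strict(Float32, 2, 2)  # 2x2 Float32 matrix
# strict(1.0, 2, 2)    # MethodError: 1.0 is not a Type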
8 changes: 6 additions & 2 deletions src/JACC.jl
@@ -93,6 +93,8 @@ end

array_type(::ThreadsBackend) = Base.Array{T, N} where {T, N}

+ default_float(::Any) = Float64

function shared(x::Base.Array{T, N}) where {T, N}
return x
end
@@ -108,6 +110,8 @@ end

array_type() = array_type(default_backend())

+ default_float() = default_float(default_backend())

function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
return parallel_for(default_backend(), N, f, x...)
end
@@ -128,7 +132,7 @@ function parallel_reduce(
end

function parallel_reduce(N::Integer, f::Function, x...)
- return parallel_reduce(N, +, f, x...; init = zero(Float64))
+ return parallel_reduce(N, +, f, x...; init = zero(default_float()))
end

function parallel_reduce((M, N)::Tuple{I, I}, op, f::F, x...;
@@ -137,7 +141,7 @@ function parallel_reduce((M, N)::Tuple{I, I}, op, f::F, x...;
end

function parallel_reduce((M, N)::Tuple{Integer, Integer}, f::Function, x...)
- return parallel_reduce((M, N), +, f, x...; init = zero(Float64))
+ return parallel_reduce((M, N), +, f, x...; init = zero(default_float()))
end

end # module JACC
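With default_float threaded through, the init-less parallel_reduce forms now accumulate in the backend's preferred float rather than unconditionally in Float64. A hedged usage sketch, assuming the kernel convention f(i, x...) shown in the extensions above:

# Sums ten ones; both the array eltype and the reduction's init value
# come from JACC.default_float() (Float32 on devices without FP64).
s = JACC.parallel_reduce(10, (i, a) -> a[i], JACC.ones(10))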
6 changes: 4 additions & 2 deletions src/array.jl
@@ -7,6 +7,8 @@ function ones(::ThreadsBackend, T, dims...)
return Base.ones(T, dims...)
end

- zeros(T, dims...) = zeros(default_backend(), T, dims...)
+ zeros(::Type{T}, dims...) where {T} = zeros(default_backend(), T, dims...)
+ ones(::Type{T}, dims...) where {T} = ones(default_backend(), T, dims...)

- ones(T, dims...) = ones(default_backend(), T, dims...)
+ zeros(dims...) = zeros(default_float(), dims...)
+ ones(dims...) = ones(default_float(), dims...)
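The two new no-type methods let callers omit the element type entirely and inherit the backend default, while the explicit forms keep working:

x = JACC.zeros(5)          # eltype(x) == JACC.default_float()
y = JACC.ones(Float32, 5)  # explicit element type, as before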
29 changes: 29 additions & 0 deletions test/JACCTests.jl
@@ -0,0 +1,29 @@
+ module JACCTests
+
+ import JACC
+ using ReTest
+
+ const backend = JACC.JACCPreferences.backend
+
+ @static if backend == "cuda"
+ using CUDA
+ include("tests_cuda.jl")
+ elseif backend == "amdgpu"
+ using AMDGPU
+ include("tests_amdgpu.jl")
+ elseif backend == "oneapi"
+ using oneAPI
+ include("tests_oneapi.jl")
+ elseif backend == "threads"
+ include("tests_threads.jl")
+ end
+
+ const FloatType = JACC.default_float()
+ using ChangePrecision
+ @changeprecision FloatType begin
+
+ include("unittests.jl")
+
+ end # @changeprecision
+
+ end
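The @changeprecision wrapper is what lets one shared unittests.jl run on FP64-less GPUs: inside the block, untyped float literals (and Float64-defaulting calls such as rand()) are rewritten to the given type. A minimal standalone illustration, assuming ChangePrecision is installed:

using ChangePrecision

@changeprecision Float32 begin
    x = 1.0 / 3  # the literal 1.0 becomes a Float32, so x isa Float32
    y = rand()   # rewritten to rand(Float32)
end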
4 changes: 2 additions & 2 deletions test/Project.toml
@@ -1,8 +1,8 @@
[deps]
+ ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
- TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
+ ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"