Skip to content

Commit

Permalink
Remove Array type and add array function
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipFackler committed Jan 17, 2025
1 parent c47478e commit 9a47496
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 54 deletions.
8 changes: 5 additions & 3 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function JACC.parallel_reduce(
@roc groupsize=threads gridsize=1 reduce_kernel_amdgpu(
blocks, op, ret, rret)
AMDGPU.synchronize()
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function JACC.parallel_reduce(
Expand All @@ -96,7 +96,7 @@ function JACC.parallel_reduce(
@roc groupsize=(Mthreads, Nthreads) gridsize=(1, 1) reduce_kernel_amdgpu_MN(
(Mblocks, Nblocks), op, ret, rret)
AMDGPU.synchronize()
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function _parallel_for_amdgpu(N, f, x...)
Expand Down Expand Up @@ -426,6 +426,8 @@ function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N}
return shmem
end

JACC.array_type(::AMDGPUBackend) = AMDGPU.ROCArray{T, N} where {T, N}
JACC.array_type(::AMDGPUBackend) = AMDGPU.ROCArray

JACC.array(::AMDGPUBackend, x::Base.Array) = AMDGPU.ROCArray(x)

end # module JACCAMDGPU
8 changes: 5 additions & 3 deletions ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function JACC.parallel_reduce(
N, op, ret, f, x...)
CUDA.@sync @cuda threads=threads blocks=1 shmem=shmem_size reduce_kernel_cuda(
blocks, op, ret, rret)
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function JACC.parallel_reduce(
Expand All @@ -97,7 +97,7 @@ function JACC.parallel_reduce(
(M, N), op, ret, f, x...)
CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(1, 1) shmem=shmem_size reduce_kernel_cuda_MN(
(Mblocks, Nblocks), op, ret, rret)
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function _parallel_for_cuda(N, f, x...)
Expand Down Expand Up @@ -432,6 +432,8 @@ function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
return shmem
end

JACC.array_type(::CUDABackend) = CUDA.CuArray{T, N} where {T, N}
JACC.array_type(::CUDABackend) = CUDA.CuArray

JACC.array(::CUDABackend, x::Base.Array) = CUDA.CuArray(x)

end # module JACCCUDA
8 changes: 5 additions & 3 deletions ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function JACC.parallel_reduce(
N, op, ret, f, x...)
oneAPI.@sync @oneapi items=items groups=1 reduce_kernel_oneapi(
N, op, ret, rret)
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function JACC.parallel_reduce(
Expand All @@ -80,7 +80,7 @@ function JACC.parallel_reduce(
(M, N), op, ret, f, x...)
oneAPI.@sync @oneapi items=(Mitems, Nitems) groups=(1, 1) reduce_kernel_oneapi_MN(
(Mgroups, Ngroups), op, ret, rret)
return Core.Array(rret)[]
return Base.Array(rret)[]
end

function _parallel_for_oneapi(N, f, x...)
Expand Down Expand Up @@ -402,7 +402,9 @@ function JACC.shared(x::oneDeviceArray{T, N}) where {T, N}
return shmem
end

JACC.array_type(::oneAPIBackend) = oneAPI.oneArray{T, N} where {T, N}
JACC.array_type(::oneAPIBackend) = oneAPI.oneArray

JACC.array(::oneAPIBackend, x::Base.Array) = oneAPI.oneArray(x)

DefaultFloat = Union{Type, Nothing}

Expand Down
25 changes: 10 additions & 15 deletions src/JACC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ using .Experimental

get_backend(::Val{:threads}) = ThreadsBackend()

export Array, @atomic
export parallel_for
export parallel_reduce

global Array
export array_type, array
export default_float
export @atomic
export parallel_for, parallel_reduce
export shared

function parallel_for(
::ThreadsBackend, N::I, f::F, x...) where {I <: Integer, F <: Function}
Expand Down Expand Up @@ -94,25 +94,20 @@ function parallel_reduce(
return ret
end

array_type(::ThreadsBackend) = Base.Array{T, N} where {T, N}
array_type(::ThreadsBackend) = Base.Array

array(::ThreadsBackend, x::Base.Array) = x

default_float(::Any) = Float64

function shared(x::Base.Array{T, N}) where {T, N}
return x
end

struct Array{T, N} end
function (::Type{Array{T, N}})(args...; kwargs...) where {T, N}
array_type(){T, N}(args...; kwargs...)
end
function (::Type{Array{T}})(args...; kwargs...) where {T}
array_type(){T}(args...; kwargs...)
end
(::Type{Array})(args...; kwargs...) = array_type()(args...; kwargs...)

array_type() = array_type(default_backend())

array(x::Base.Array) = array(default_backend(), x)

default_float() = default_float(default_backend())

function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
Expand Down
60 changes: 30 additions & 30 deletions test/unittests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
N = 10
dims = (N)
a = round.(rand(Float32, dims) * 100)
a_expected = a .+ 5.0

a_device = JACC.Array(a)
a_device = JACC.array(a)
JACC.parallel_for(N, f, a_device)

a_expected = a .+ 5.0
@test Core.Array(a_device)a_expected rtol=1e-5
@test Base.Array(a_device)a_expected rtol=1e-5
end

@testset "AXPY" begin
Expand All @@ -36,42 +36,42 @@ end
y = round.(rand(Float32, N) * 100)
alpha = 2.5

x_device = JACC.Array(x)
y_device = JACC.Array(y)
x_device = JACC.array(x)
y_device = JACC.array(y)
JACC.parallel_for(N, axpy, alpha, x_device, y_device)

x_expected = x
seq_axpy(N, alpha, x_expected, y)

@test Core.Array(x_device)x_expected rtol=1e-1
@test Base.Array(x_device)x_expected rtol=1e-1
end

@testset "zeros" begin
N = 10
x = JACC.zeros(N)
@test eltype(x) == FloatType
@test zeros(N)Core.Array(x) rtol=1e-5
@test zeros(N)Base.Array(x) rtol=1e-5

function add_one(i, x)
@inbounds x[i] += 1
end

JACC.parallel_for(N, add_one, x)
@test ones(N)Core.Array(x) rtol=1e-5
@test ones(N)Base.Array(x) rtol=1e-5
end

@testset "ones" begin
N = 10
x = JACC.ones(N)
@test eltype(x) == FloatType
@test ones(N)Core.Array(x) rtol=1e-5
@test ones(N)Base.Array(x) rtol=1e-5

function minus_one(i, x)
@inbounds x[i] -= 1
end

JACC.parallel_for(N, minus_one, x)
@test zeros(N)Core.Array(x) rtol=1e-5
@test zeros(N)Base.Array(x) rtol=1e-5
end

@testset "AtomicCounter" begin
Expand All @@ -84,24 +84,24 @@ end
# Generate random vectors x and y of length N for the interval [0, 100]
alpha = 2.5

x = JACC.Array(round.(rand(Float32, N) * 100))
y = JACC.Array(round.(rand(Float32, N) * 100))
counter = JACC.Array{Int32}([0])
x = JACC.array(round.(rand(Float32, N) * 100))
y = JACC.array(round.(rand(Float32, N) * 100))
counter = JACC.array(Int32[0])
JACC.parallel_for(N, axpy_counter!, alpha, x, y, counter)

@test Core.Array(counter)[1] == N
@test Base.Array(counter)[1] == N
end

@testset "reduce" begin
a = JACC.Array([1 for i=1:10])
a = JACC.array([1 for i=1:10])
@test JACC.parallel_reduce(a) == 10
@test JACC.parallel_reduce(min, a) == 1
a2 = JACC.ones(Int, (2,2))
@test JACC.parallel_reduce(min, a2) == 1

SIZE = 1000
ah = randn(FloatType, SIZE)
ad = JACC.Array(ah)
ad = JACC.array(ah)
mxd = JACC.parallel_reduce(SIZE, max, (i, a) -> a[i], ad; init = -Inf)
@test mxd == maximum(ah)
mxd = JACC.parallel_reduce(max, ad)
Expand All @@ -112,7 +112,7 @@ end
@test mnd == minimum(ah)

ah2 = randn(FloatType, (SIZE, SIZE))
ad2 = JACC.Array(ah2)
ad2 = JACC.array(ah2)
mxd = JACC.parallel_reduce((SIZE, SIZE), max, (i, j, a) -> a[i, j], ad2; init = -Inf)
@test mxd == maximum(ah2)
mxd = JACC.parallel_reduce(max, ad2)
Expand Down Expand Up @@ -220,11 +220,11 @@ end

seq_scal(1_000, alpha, x)
JACC.BLAS.scal(1_000, alpha, jx)
@test xCore.Array(jx) rtol=1e-8
@test xBase.Array(jx) rtol=1e-8

seq_axpy(1_000, alpha, x, y)
JACC.BLAS.axpy(1_000, alpha, jx, jy)
@test xCore.Array(jx) atol=1e-8
@test xBase.Array(jx) atol=1e-8

r1 = seq_dot(1_000, x, y)
r2 = JACC.BLAS.dot(1_000, jx, jy)
Expand All @@ -239,8 +239,8 @@ end

seq_swap(1_000, x, y1)
JACC.BLAS.swap(1_000, jx, jy1)
@test x == Core.Array(jx)
@test y1 == Core.Array(jy1)
@test x == Base.Array(jx)
@test y1 == Base.Array(jy1)
end

@testset "Add-2D" begin
Expand All @@ -257,7 +257,7 @@ end
JACC.parallel_for((M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, M, N)
@test Core.Array(C)C_expected rtol=1e-5
@test Base.Array(C)C_expected rtol=1e-5
end

@testset "Add-3D" begin
Expand All @@ -275,7 +275,7 @@ end
JACC.parallel_for((L, M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, L, M, N)
@test Core.Array(C)C_expected rtol=1e-5
@test Base.Array(C)C_expected rtol=1e-5
end

@testset "CG" begin
Expand Down Expand Up @@ -444,17 +444,17 @@ end
w = ones(9)
t = 1.0

df = JACC.Array(f)
df1 = JACC.Array(f1)
df2 = JACC.Array(f2)
dcx = JACC.Array(cx)
dcy = JACC.Array(cy)
dw = JACC.Array(w)
df = JACC.array(f)
df1 = JACC.array(f1)
df2 = JACC.array(f2)
dcx = JACC.array(cx)
dcy = JACC.array(cy)
dw = JACC.array(w)

JACC.parallel_for(
(SIZE, SIZE), lbm_kernel, df, df1, df2, t, dw, dcx, dcy, SIZE)

lbm_threads(f, f1, f2, t, w, cx, cy, SIZE)

@test f2Core.Array(df2) rtol=1e-1
@test f2Base.Array(df2) rtol=1e-1
end

0 comments on commit 9a47496

Please sign in to comment.