diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl index ea485b4..9200bcc 100644 --- a/ext/JACCAMDGPU/JACCAMDGPU.jl +++ b/ext/JACCAMDGPU/JACCAMDGPU.jl @@ -78,7 +78,7 @@ function JACC.parallel_reduce( @roc groupsize=threads gridsize=1 reduce_kernel_amdgpu( blocks, op, ret, rret) AMDGPU.synchronize() - return Core.Array(rret)[] + return Base.Array(rret)[] end function JACC.parallel_reduce( @@ -96,7 +96,7 @@ function JACC.parallel_reduce( @roc groupsize=(Mthreads, Nthreads) gridsize=(1, 1) reduce_kernel_amdgpu_MN( (Mblocks, Nblocks), op, ret, rret) AMDGPU.synchronize() - return Core.Array(rret)[] + return Base.Array(rret)[] end function _parallel_for_amdgpu(N, f, x...) @@ -426,6 +426,8 @@ function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N} return shmem end -JACC.array_type(::AMDGPUBackend) = AMDGPU.ROCArray{T, N} where {T, N} +JACC.array_type(::AMDGPUBackend) = AMDGPU.ROCArray + +JACC.array(::AMDGPUBackend, x::Base.Array) = AMDGPU.ROCArray(x) end # module JACCAMDGPU diff --git a/ext/JACCCUDA/JACCCUDA.jl b/ext/JACCCUDA/JACCCUDA.jl index 1db4a98..e913f09 100644 --- a/ext/JACCCUDA/JACCCUDA.jl +++ b/ext/JACCCUDA/JACCCUDA.jl @@ -80,7 +80,7 @@ function JACC.parallel_reduce( N, op, ret, f, x...) CUDA.@sync @cuda threads=threads blocks=1 shmem=shmem_size reduce_kernel_cuda( blocks, op, ret, rret) - return Core.Array(rret)[] + return Base.Array(rret)[] end function JACC.parallel_reduce( @@ -97,7 +97,7 @@ function JACC.parallel_reduce( (M, N), op, ret, f, x...) CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(1, 1) shmem=shmem_size reduce_kernel_cuda_MN( (Mblocks, Nblocks), op, ret, rret) - return Core.Array(rret)[] + return Base.Array(rret)[] end function _parallel_for_cuda(N, f, x...) @@ -432,6 +432,8 @@ function JACC.shared(x::CuDeviceArray{T, N}) where {T, N} return shmem end -JACC.array_type(::CUDABackend) = CUDA.CuArray{T, N} where {T, N} +JACC.array_type(::CUDABackend) = CUDA.CuArray + +JACC.array(::CUDABackend, x::Base.Array) = CUDA.CuArray(x) end # module JACCCUDA diff --git a/ext/JACCONEAPI/JACCONEAPI.jl b/ext/JACCONEAPI/JACCONEAPI.jl index 82cbe46..04ee20b 100644 --- a/ext/JACCONEAPI/JACCONEAPI.jl +++ b/ext/JACCONEAPI/JACCONEAPI.jl @@ -64,7 +64,7 @@ function JACC.parallel_reduce( N, op, ret, f, x...) oneAPI.@sync @oneapi items=items groups=1 reduce_kernel_oneapi( N, op, ret, rret) - return Core.Array(rret)[] + return Base.Array(rret)[] end function JACC.parallel_reduce( @@ -80,7 +80,7 @@ function JACC.parallel_reduce( (M, N), op, ret, f, x...) oneAPI.@sync @oneapi items=(Mitems, Nitems) groups=(1, 1) reduce_kernel_oneapi_MN( (Mgroups, Ngroups), op, ret, rret) - return Core.Array(rret)[] + return Base.Array(rret)[] end function _parallel_for_oneapi(N, f, x...) @@ -402,7 +402,9 @@ function JACC.shared(x::oneDeviceArray{T, N}) where {T, N} return shmem end -JACC.array_type(::oneAPIBackend) = oneAPI.oneArray{T, N} where {T, N} +JACC.array_type(::oneAPIBackend) = oneAPI.oneArray + +JACC.array(::oneAPIBackend, x::Base.Array) = oneAPI.oneArray(x) DefaultFloat = Union{Type, Nothing} diff --git a/src/JACC.jl b/src/JACC.jl index 8dda660..f6509f7 100644 --- a/src/JACC.jl +++ b/src/JACC.jl @@ -28,11 +28,11 @@ using .Experimental get_backend(::Val{:threads}) = ThreadsBackend() -export Array, @atomic -export parallel_for -export parallel_reduce - -global Array +export array_type, array +export default_float +export @atomic +export parallel_for, parallel_reduce +export shared function parallel_for( ::ThreadsBackend, N::I, f::F, x...) where {I <: Integer, F <: Function} @@ -94,7 +94,9 @@ function parallel_reduce( return ret end -array_type(::ThreadsBackend) = Base.Array{T, N} where {T, N} +array_type(::ThreadsBackend) = Base.Array + +array(::ThreadsBackend, x::Base.Array) = x default_float(::Any) = Float64 @@ -102,17 +104,10 @@ function shared(x::Base.Array{T, N}) where {T, N} return x end -struct Array{T, N} end -function (::Type{Array{T, N}})(args...; kwargs...) where {T, N} - array_type(){T, N}(args...; kwargs...) -end -function (::Type{Array{T}})(args...; kwargs...) where {T} - array_type(){T}(args...; kwargs...) -end -(::Type{Array})(args...; kwargs...) = array_type()(args...; kwargs...) - array_type() = array_type(default_backend()) +array(x::Base.Array) = array(default_backend(), x) + default_float() = default_float(default_backend()) function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} diff --git a/test/unittests.jl b/test/unittests.jl index 5619954..8c8b6ef 100644 --- a/test/unittests.jl +++ b/test/unittests.jl @@ -9,12 +9,12 @@ N = 10 dims = (N) a = round.(rand(Float32, dims) * 100) + a_expected = a .+ 5.0 - a_device = JACC.Array(a) + a_device = JACC.array(a) JACC.parallel_for(N, f, a_device) - a_expected = a .+ 5.0 - @test Core.Array(a_device)≈a_expected rtol=1e-5 + @test Base.Array(a_device)≈a_expected rtol=1e-5 end @testset "AXPY" begin @@ -36,42 +36,42 @@ end y = round.(rand(Float32, N) * 100) alpha = 2.5 - x_device = JACC.Array(x) - y_device = JACC.Array(y) + x_device = JACC.array(x) + y_device = JACC.array(y) JACC.parallel_for(N, axpy, alpha, x_device, y_device) x_expected = x seq_axpy(N, alpha, x_expected, y) - @test Core.Array(x_device)≈x_expected rtol=1e-1 + @test Base.Array(x_device)≈x_expected rtol=1e-1 end @testset "zeros" begin N = 10 x = JACC.zeros(N) @test eltype(x) == FloatType - @test zeros(N)≈Core.Array(x) rtol=1e-5 + @test zeros(N)≈Base.Array(x) rtol=1e-5 function add_one(i, x) @inbounds x[i] += 1 end JACC.parallel_for(N, add_one, x) - @test ones(N)≈Core.Array(x) rtol=1e-5 + @test ones(N)≈Base.Array(x) rtol=1e-5 end @testset "ones" begin N = 10 x = JACC.ones(N) @test eltype(x) == FloatType - @test ones(N)≈Core.Array(x) rtol=1e-5 + @test ones(N)≈Base.Array(x) rtol=1e-5 function minus_one(i, x) @inbounds x[i] -= 1 end JACC.parallel_for(N, minus_one, x) - @test zeros(N)≈Core.Array(x) rtol=1e-5 + @test zeros(N)≈Base.Array(x) rtol=1e-5 end @testset "AtomicCounter" begin @@ -84,16 +84,16 @@ end # Generate random vectors x and y of length N for the interval [0, 100] alpha = 2.5 - x = JACC.Array(round.(rand(Float32, N) * 100)) - y = JACC.Array(round.(rand(Float32, N) * 100)) - counter = JACC.Array{Int32}([0]) + x = JACC.array(round.(rand(Float32, N) * 100)) + y = JACC.array(round.(rand(Float32, N) * 100)) + counter = JACC.array(Int32[0]) JACC.parallel_for(N, axpy_counter!, alpha, x, y, counter) - @test Core.Array(counter)[1] == N + @test Base.Array(counter)[1] == N end @testset "reduce" begin - a = JACC.Array([1 for i=1:10]) + a = JACC.array([1 for i=1:10]) @test JACC.parallel_reduce(a) == 10 @test JACC.parallel_reduce(min, a) == 1 a2 = JACC.ones(Int, (2,2)) @@ -101,7 +101,7 @@ end SIZE = 1000 ah = randn(FloatType, SIZE) - ad = JACC.Array(ah) + ad = JACC.array(ah) mxd = JACC.parallel_reduce(SIZE, max, (i, a) -> a[i], ad; init = -Inf) @test mxd == maximum(ah) mxd = JACC.parallel_reduce(max, ad) @@ -112,7 +112,7 @@ end @test mnd == minimum(ah) ah2 = randn(FloatType, (SIZE, SIZE)) - ad2 = JACC.Array(ah2) + ad2 = JACC.array(ah2) mxd = JACC.parallel_reduce((SIZE, SIZE), max, (i, j, a) -> a[i, j], ad2; init = -Inf) @test mxd == maximum(ah2) mxd = JACC.parallel_reduce(max, ad2) @@ -220,11 +220,11 @@ end seq_scal(1_000, alpha, x) JACC.BLAS.scal(1_000, alpha, jx) - @test x≈Core.Array(jx) rtol=1e-8 + @test x≈Base.Array(jx) rtol=1e-8 seq_axpy(1_000, alpha, x, y) JACC.BLAS.axpy(1_000, alpha, jx, jy) - @test x≈Core.Array(jx) atol=1e-8 + @test x≈Base.Array(jx) atol=1e-8 r1 = seq_dot(1_000, x, y) r2 = JACC.BLAS.dot(1_000, jx, jy) @@ -239,8 +239,8 @@ end seq_swap(1_000, x, y1) JACC.BLAS.swap(1_000, jx, jy1) - @test x == Core.Array(jx) - @test y1 == Core.Array(jy1) + @test x == Base.Array(jx) + @test y1 == Base.Array(jy1) end @testset "Add-2D" begin @@ -257,7 +257,7 @@ end JACC.parallel_for((M, N), add!, A, B, C) C_expected = Float32(2.0) .* ones(Float32, M, N) - @test Core.Array(C)≈C_expected rtol=1e-5 + @test Base.Array(C)≈C_expected rtol=1e-5 end @testset "Add-3D" begin @@ -275,7 +275,7 @@ end JACC.parallel_for((L, M, N), add!, A, B, C) C_expected = Float32(2.0) .* ones(Float32, L, M, N) - @test Core.Array(C)≈C_expected rtol=1e-5 + @test Base.Array(C)≈C_expected rtol=1e-5 end @testset "CG" begin @@ -444,17 +444,17 @@ end w = ones(9) t = 1.0 - df = JACC.Array(f) - df1 = JACC.Array(f1) - df2 = JACC.Array(f2) - dcx = JACC.Array(cx) - dcy = JACC.Array(cy) - dw = JACC.Array(w) + df = JACC.array(f) + df1 = JACC.array(f1) + df2 = JACC.array(f2) + dcx = JACC.array(cx) + dcy = JACC.array(cy) + dw = JACC.array(w) JACC.parallel_for( (SIZE, SIZE), lbm_kernel, df, df1, df2, t, dw, dcx, dcy, SIZE) lbm_threads(f, f1, f2, t, w, cx, cy, SIZE) - @test f2≈Core.Array(df2) rtol=1e-1 + @test f2≈Base.Array(df2) rtol=1e-1 end