diff --git a/test/blas.jl b/test/blas.jl index 5f296941..88305128 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -5,19 +5,19 @@ using LinearAlgebra CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH) @testset "BLAS API" begin - @testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true], + @testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true], transpose_b = [false, true], - (A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)] + (AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)] @testcase "(M = $M, N = $N, K = $K)" for M in min_dimension .* [1, 2], N in min_dimension .* [1, 2], K in min_dimension .* [1, 2] - alpha = rand(A_type) + alpha = rand(AB_type) beta = rand(CD_type) - a_h = rand(A_type, (M, K)) / sqrt(A_type(K)) - b_h = rand(B_type, (K, N)) / sqrt(B_type(K)) + a_h = rand(AB_type, (M, K)) / sqrt(AB_type(K)) + b_h = rand(AB_type, (K, N)) / sqrt(AB_type(K)) c_h = rand(CD_type, (M, N)) # Transpose input if necessary @@ -33,7 +33,7 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH) c_cublas = CuArray(c_h) CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_cublas) - @test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type)))); + @test Array(c_gemmkernels) ≈ Array(c_cublas) rtol=sqrt(eps(AB_type)) end end end diff --git a/test/matmul.jl b/test/matmul.jl index c513f479..90fc0baa 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -6,7 +6,7 @@ using LinearAlgebra ################################################################################ @testset "Matmul API" begin - @testset "FPU GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for + @testset "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for (A_type, B_type, CD_type, min_dimension) in [ (Float16, Float16, Float32, 128), (Float32, Float32, Float32, 128), (Float32, Float32, Float64, 128), (Float64, Float64, Float64, 128), (Int16, Int16, Int16, 128), (Int32, Int32, Int32, 128), (Int64, Int64, Int64, 128), @@ -63,10 +63,11 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(a_h) : a_h new_b_h = transpose_b ? transpose(b_h) : b_h + mul!(c_h, new_a_h, new_b_h, alpha, beta) if A_type <: Integer - @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d))) + @test c_h ≈ Array(d) else - @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type)))) + @test c_h ≈ Array(d) rtol=sqrt(eps(A_type)) end end end @@ -120,11 +121,12 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(a_h) : a_h new_b_h = transpose_b ? transpose(b_h) : b_h - @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type)))) + mul!(c_h, new_a_h, new_b_h, alpha, beta) + @test c_h ≈ Array(d) rtol=sqrt(eps(A_type)) end end - @testset "TROPICAL GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for + @testset "TROPICAL GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)], transpose_a = [false, true], transpose_b = [false, true], @@ -172,12 +174,12 @@ using LinearAlgebra GemmKernels.matmul(a, b, c, d, conf; kernel = Kernel.matmul_pipelined) - @test all(isapprox.(d_h, Array(d); rtol = sqrt(eps(A_type)))) + @test d_h ≈ Array(d) rtol=sqrt(eps(A_type)) end end - @testset "WMMA GEMM $(AB_type)*$(AB_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true], + @testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true], transpose_b = [false, true], (AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)] @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]]) @@ -220,7 +222,8 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(a_h) : a_h new_b_h = transpose_b ? transpose(b_h) : b_h - @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(AB_type)))) + mul!(c_h, new_a_h, new_b_h, alpha, beta) + @test c_h ≈ Array(d) rtol=sqrt(eps(AB_type)) end end @@ -271,7 +274,8 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(a_h) : a_h new_b_h = transpose_b ? transpose(b_h) : b_h - @test all(isapprox.(Float32.(new_a_h) * Float32.(new_b_h) + c_h .+ Array(bias), Array(d); rtol = sqrt(eps(Float16)))) + mul!(c_h, new_a_h, new_b_h, true, true) + @test c_h .+ Array(bias) ≈ Array(d) rtol=sqrt(eps(Float16)) end end @@ -281,7 +285,7 @@ using LinearAlgebra transpose_a = false - a_h = rand(Float16, M); + a_h = rand(Float16, M) b_h = rand(Float16, (K, N)) / sqrt(Float16(K)) c_h = rand(Float32, (M, N)) @@ -315,7 +319,8 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(a_h) : a_h new_b_h = transpose_b ? transpose(b_h) : b_h - @test all(isapprox.(Float32.(Diagonal(new_a_h)) * Float32.(new_b_h) + c_h, Array(d); rtol = sqrt(eps(Float16)))) + mul!(c_h, Diagonal(new_a_h), new_b_h, true, true) + @test c_h ≈ Array(d) rtol=sqrt(eps(Float16)) end end @@ -323,18 +328,18 @@ using LinearAlgebra transpose_b = [false, true] @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) = [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)] - a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K)); - b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K)); - c_h = rand(Complex{Float32}, (M, N)); + a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K)) + b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K)) + c_h = rand(Complex{Float32}, (M, N)) # Transpose input if necessary a_h = transpose_a ? transpose(a_h) : a_h b_h = transpose_b ? transpose(b_h) : b_h - a = CuArray(a_h); - b = CuArray(b_h); - c = CuArray(c_h); - d = similar(c); + a = CuArray(a_h) + b = CuArray(b_h) + c = CuArray(c_h) + d = similar(c) conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), @@ -378,22 +383,21 @@ using LinearAlgebra new_a_h = transpose_a ? transpose(new_a_h) : new_a_h new_b_h = transpose_b ? transpose(new_b_h) : new_b_h - # TODO: Figure out why changing this to a * b + c = d instead of a * b = d - c - # makes tests fail for CC (see #19). - @test all(isapprox.(Complex{Float32}.(new_a_h) * Complex{Float32}.(new_b_h), Array(d) - c_h; rtol=sqrt(eps(Float16)))); + mul!(c_h, new_a_h, new_b_h, true, true) + @test c_h ≈ Array(d) rtol=sqrt(eps(Float16)) end end @testset "WMMA Dual GEMM" begin @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)] - a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K)); - b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K)); - c_h = rand(Complex{Float32}, (M, N)); + a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K)) + b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K)) + c_h = rand(Complex{Float32}, (M, N)) - a = CuArray(a_h); - b = CuArray(b_h); - c = CuArray(c_h); - d = similar(c); + a = CuArray(a_h) + b = CuArray(b_h) + c = CuArray(c_h) + d = similar(c) conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), @@ -432,7 +436,8 @@ using LinearAlgebra c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h) d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d)) - @test all(isapprox.(a_dual * b_dual + c_dual, d_dual; rtol=sqrt(eps(Float16)))); + mul!(c_dual, a_dual, b_dual, true, true) + @test c_dual ≈ d_dual rtol=sqrt(eps(Float16)) end end end diff --git a/test/runtests.jl b/test/runtests.jl index d2dfb96f..eadaeb22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,4 +22,4 @@ withenv("JULIA_NUM_THREADS" => 1, "OPENBLAS_NUM_THREADS" => 1) do end @everywhere using XUnit -runtests("tests.jl") +runtests("tests.jl", ARGS...)