Commit e4db4a2

Add tests

christiangnrd committed Nov 6, 2024
1 parent b76de49 commit e4db4a2
Showing 5 changed files with 37 additions and 21 deletions.
15 changes: 9 additions & 6 deletions test/array.jl
@@ -1,5 +1,9 @@
STORAGEMODES = [Metal.PrivateStorage, Metal.SharedStorage, Metal.ManagedStorage]

const FILL_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
Float16, Float32]
Metal.metal_support() >= v"3.1" && push!(FILL_TYPES, BFloat16)

@testset "array" begin

let arr = MtlVector{Int}(undef, 1)
@@ -27,8 +31,7 @@ end
@test mtl(1:3) === 1:3


# Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
# Only bfloat missing
# Section 2.1 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
supported_number_types = [Float16 => Float16,
Float32 => Float32,
Float64 => Float32,
@@ -41,6 +44,8 @@ end
UInt32 => UInt32,
UInt64 => UInt64,
UInt8 => UInt8]
Metal.metal_support() >= v"3.1" && push!(supported_number_types, BFloat16 => BFloat16)

# Test supported types and ensure only Float64 get converted to Float32
for (SrcType, TargType) in supported_number_types
@test mtl(SrcType[1]) isa MtlArray{TargType}
@@ -227,8 +232,7 @@ end

end

@testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
Float16, Float32]
@testset "fill($T)" for T in FILL_TYPES
broken466a = T ∈ [Int8,UInt8]
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)

Expand Down Expand Up @@ -267,8 +271,7 @@ end
end
end

@testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
Float16, Float32]
@testset "fill!($T)" for T in FILL_TYPES
broken466a = T ∈ [Int8,UInt8]
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)

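The array.jl changes above hoist the fill-test element types into a shared FILL_TYPES constant, gate BFloat16 on Metal 3.1+, and extend the mtl conversion table. Below is a minimal sketch of how those pieces fit together; the SAMPLE_TYPES name and the round-trip testset are illustrative stand-ins, not code from the diff:

```julia
using Metal, BFloat16s, Test

# Base element types every supported device handles (mirrors FILL_TYPES).
SAMPLE_TYPES = [Int8, UInt8, Int16, UInt16, Int32, UInt32,
                Int64, UInt64, Float16, Float32]

# BFloat16 requires Metal 3.1 or newer, so it is appended conditionally.
Metal.metal_support() >= v"3.1" && push!(SAMPLE_TYPES, BFloat16)

@testset "roundtrip ($T)" for T in SAMPLE_TYPES
    x = T.(1:4)
    @test Array(MtlVector(x)) == x      # host -> device -> host round trip
end

# mtl keeps supported element types as-is and demotes Float64 to Float32.
@test mtl(Int64[1])     isa MtlVector{Int64}
@test mtl(Float64[1.0]) isa MtlVector{Float32}
```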
1 change: 1 addition & 0 deletions test/device/intrinsics.jl
@@ -276,6 +276,7 @@ end

@testset "parametrically typed" begin
typs = [Int32, Int64, Float32]
metal_support() >= v"3.1" && push!(typs, BFloat16)
@testset for typ in typs
function kernel(d::MtlDeviceArray{T}, n) where {T}
t = thread_position_in_threadgroup_1d()
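The intrinsics change simply adds BFloat16 to the element types covered by the parametrically typed kernel test. For context, here is a self-contained sketch of such a kernel launch; copy_kernel and the grid-position intrinsic form my own minimal example, not the test's kernel:

```julia
using Metal

# Each thread copies one element; the kernel is compiled once per element type T.
function copy_kernel(dst::MtlDeviceArray{T}, src::MtlDeviceArray{T}) where {T}
    i = thread_position_in_grid_1d()
    if i <= length(dst)
        @inbounds dst[i] = src[i]
    end
    return
end

a = MtlArray(rand(Float32, 256))
b = similar(a)
@metal threads=256 copy_kernel(b, a)
@assert Array(b) == Array(a)
```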
17 changes: 11 additions & 6 deletions test/mps/linalg.jl
@@ -140,9 +140,11 @@ end
return perm, y
end
end
@testset "$ftype" for ftype in (Float16, Float32)
@testset "$ftype" ftypes = [Float16, Float32]

@testset "$ftype" for ftype in ftypes
# Normal operation
for (shp,k) in [((3,1), 2), ((20,30), 5)]
@testset "$shp, k=$k" for (shp,k) in [((3,1), 2), ((20,30), 5)]
cpu_a = rand(ftype, shp...)

#topk
@@ -163,11 +165,13 @@ end
@test Array(i) == cpu_i
@test Array(v) == cpu_v
end

# test too big `k`
shp = (20,30)
k = 17

cpu_a = rand(ftype, shp...)
cpu_i, cpu_v = cpu_topk(cpu_a, k)
@testset "$shp, k=$k" begin
cpu_a = rand(ftype, shp...)
cpu_i, cpu_v = cpu_topk(cpu_a, k)

a = MtlMatrix(cpu_a)
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk(a, k)
@@ -176,7 +180,8 @@ end
i = MtlMatrix{UInt32}(undef, (k, shp[2]))
v = MtlMatrix{ftype}(undef, (k, shp[2]))

@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk!(a, i, v, k)
@test_throws "MPSMatrixFindTopK does not support values of k > 16" i, v = MPS.topk!(a, i, v, k)
end
end
end

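The linalg changes restructure the MPS.topk tests into named @testsets and keep the k > 16 error check. A hedged sketch of the API those tests exercise, inferred from the diff; Metal.MPS is assumed to be the submodule that provides topk:

```julia
using Metal, Test
const MPS = Metal.MPS

a = MtlMatrix(rand(Float32, 20, 30))

# topk returns a (k × ncols) matrix of UInt32 indices and a matching value matrix.
i, v = MPS.topk(a, 5)
@test size(i) == (5, 30)
@test size(v) == (5, 30)

# MPSMatrixFindTopK only supports k ≤ 16; the tests match the error message.
@test_throws "MPSMatrixFindTopK does not support values of k > 16" MPS.topk(a, 17)
```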
15 changes: 14 additions & 1 deletion test/runtests.jl
@@ -73,14 +73,27 @@ for (rootpath, dirs, files) in walkdir(@__DIR__)
test_runners[file] = ()->include("$(@__DIR__)/$file.jl")
end
end

## GPUArrays testsuite
const gpuarr_eltypes = [Int16, Int32, Int64,
Complex{Int16}, Complex{Int32}, Complex{Int64},
Float16, Float32,
ComplexF16, ComplexF32]
const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)

# Add BFloat16 for tests that use it
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)

for name in keys(TestSuite.tests)
if Metal.DefaultStorageMode != Metal.PrivateStorage && name == "indexing scalar"
# GPUArrays' scalar indexing tests assume that indexing is not supported
continue
end

tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes

push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name")
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray)
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)
end
unique!(tests)

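The runtests.jl hunk builds the GPUArrays element-type list once, appends BFloat16 when the device supports it, and hands a BFloat16-free copy to the "random" suite. A condensed sketch of that selection, assuming GPUArrays' TestSuite is already included as test/setup.jl does:

```julia
using Metal, BFloat16s

gpuarr_eltypes = [Int16, Int32, Int64,
                  Complex{Int16}, Complex{Int32}, Complex{Int64},
                  Float16, Float32, ComplexF16, ComplexF32]
gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)            # copy taken before push!
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)

for name in keys(TestSuite.tests)
    # The "random" suite cannot handle BFloat16, so it keeps the reduced list.
    eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
    TestSuite.tests[name](MtlArray; eltypes)
end
```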
10 changes: 2 additions & 8 deletions test/setup.jl
@@ -1,4 +1,4 @@
using Distributed, Test, Metal, Adapt, ObjectiveC, ObjectiveC.Foundation
using Distributed, Test, Metal, BFloat16s, Adapt, ObjectiveC, ObjectiveC.Foundation

Metal.functional() || error("Metal.jl is not functional on this system")

@@ -10,12 +10,6 @@ gpuarrays_root = dirname(dirname(gpuarrays))
include(joinpath(gpuarrays_root, "test", "testsuite.jl"))
testf(f, xs...; kwargs...) = TestSuite.compare(f, MtlArray, xs...; kwargs...)

const eltypes = [Int16, Int32, Int64,
Complex{Int16}, Complex{Int32}, Complex{Int64},
Float16, Float32,
ComplexF16, ComplexF32]
TestSuite.supported_eltypes(::Type{<:MtlArray}) = eltypes

const runtime_validation = get(ENV, "MTL_DEBUG_LAYER", "0") != "0"
const shader_validation = get(ENV, "MTL_SHADER_VALIDATION", "0") != "0"

@@ -32,7 +26,7 @@ function runtests(f, name)
# generate a temporary module to execute the tests in
mod_name = Symbol("Test", rand(1:100), "Main_", replace(name, '/' => '_'))
mod = @eval(Main, module $mod_name end)
@eval(mod, using Test, Random, Metal)
@eval(mod, using Test, Random, Metal, BFloat16s)

let id = myid()
wait(@spawnat 1 print_testworker_started(name, id))
