-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support rand! and rand using MPS where appropriate
Also add tests
- Loading branch information
1 parent
36b4453
commit 9f55467
Showing
10 changed files
with
250 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
using Random | ||
|
||
""" | ||
MPS.RNG() | ||
A random number generator using `rand()` in a device kernel. | ||
""" | ||
mutable struct RNG <: AbstractRNG | ||
seed::UInt | ||
counter::UInt32 | ||
|
||
function RNG(seed::Integer) | ||
new(seed%UInt, 0) | ||
end | ||
RNG(seed::UInt, counter::UInt32) = new(seed, counter) | ||
end | ||
|
||
make_seed() = Base.rand(RandomDevice(), UInt) | ||
|
||
RNG() = RNG(make_seed()) | ||
|
||
Base.copy(rng::RNG) = RNG(rng.seed, rng.counter) | ||
Base.hash(rng::RNG, h::UInt) = hash(rng.seed, hash(rng.counter, h)) | ||
Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter) | ||
|
||
function Random.seed!(rng::RNG, seed::Integer) | ||
rng.seed = seed % UInt | ||
rng.counter = 0 | ||
end | ||
|
||
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) | ||
|
||
@inline function update_state!(rng::RNG, len) | ||
new_counter = Int64(rng.counter) + len | ||
overflow, remainder = fldmod(new_counter, typemax(UInt32)) | ||
rng.seed += overflow # XXX: is this OK? | ||
rng.counter = remainder | ||
return rng | ||
end | ||
|
||
const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() | ||
function default_rng() | ||
dev = current_device() | ||
get!(GLOBAL_RNGs, dev) do | ||
RNG() | ||
end | ||
end | ||
|
||
function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
mpsvecormat = _mpsvector_rand(A, UInt32) | ||
_mpsmat_rand!(mpsvecormat, seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end | ||
function Random.rand!(rng::RNG, A::MtlArray{Float32}) | ||
mpsvecormat = _mpsvector_rand(A, Float32) | ||
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomUniformDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end | ||
function Random.randn!(rng::RNG, A::MtlArray{Float32}) | ||
mpsvecormat = _mpsvector_rand(A, Float32) | ||
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomNormalDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,74 @@ | ||
using Random | ||
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor, | ||
MPSMatrixRandomNormalDistributionDescriptor | ||
|
||
gpuarrays_rng() = GPUArrays.default_rng(MtlArray) | ||
mpsrand_rng() = MPS.default_rng() | ||
|
||
# GPUArrays in-place | ||
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A) | ||
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A) | ||
|
||
@inline function usempsrandom(A::MtlArray{T}) where {T} | ||
return (A.offset == 0 && | ||
(length(A) * sizeof(T) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0)) | ||
end | ||
|
||
# Use MPS random functionality where possible | ||
function Random.rand!(A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
if usempsrandom(A) | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
else | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
function Random.rand!(A::MtlArray{Float32}) | ||
if usempsrandom(A) | ||
@inline Random.rand!(mpsrand_rng(), A) | ||
else | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
function Random.randn!(A::MtlArray{Float32}) | ||
if usempsrandom(A) | ||
@inline Random.randn!(mpsrand_rng(), A) | ||
else | ||
@inline Random.randn!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
|
||
# GPUArrays out-of-place | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...) | ||
rand(::Type{T}, dims::Dims; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(::Type{Float32}, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims),storage}(undef, dims...)) | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
|
||
# support all dimension specifications | ||
rand(::Type{T}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(::Type{Float32}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = | ||
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# untyped out-of-place | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# seeding | ||
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed) | ||
function seed!(seed=Base.rand(UInt64)) | ||
Random.seed!(gpuarrays_rng(), seed) | ||
Random.seed!(mpsrand_rng(), seed) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,102 @@ | ||
using Random | ||
|
||
const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, | ||
UInt64] | ||
const RANDN_TYPES = [Float16, Float32] | ||
const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES]; | ||
[(randn!, T) for T in RANDN_TYPES]] | ||
const OOPLACE_TUPLES = [[(Metal.rand, T) for T in RAND_TYPES]; | ||
[(Metal.randn, T) for T in RANDN_TYPES]; | ||
[(rand, T) for T in RAND_TYPES]; | ||
[(randn, T) for T in RANDN_TYPES]] | ||
|
||
@testset "rand" begin | ||
# in-place | ||
@testset "in-place" begin | ||
@testset "$f with $T" for (f, T) in INPLACE_TUPLES | ||
@testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16)) | ||
A = MtlArray{T}(undef, d) | ||
fill!(A, T(0)) | ||
f(A) | ||
@test Metal.usempsrandom(A) == | ||
((prod(d) * sizeof(T)) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0) | ||
@test !iszero(collect(A)) | ||
end | ||
end | ||
end | ||
|
||
# in-place contiguous views | ||
@testset "in-place for views" begin | ||
@testset "$f with $T" for (f, T) in INPLACE_TUPLES | ||
alen = 100 | ||
A = MtlArray{T}(undef, alen) | ||
function test_view!(X::MtlArray{T}, idx; shouldusemps) where {T} | ||
fill!(X, T(0)) | ||
view_X = @view X[idx] | ||
f(view_X) | ||
cpuX = collect(X) | ||
@test Metal.usempsrandom(view_X) == shouldusemps | ||
@test !iszero(cpuX[idx]) | ||
@test iszero(cpuX[1:alen .∉ Ref(idx)]) | ||
return | ||
end | ||
|
||
# Test when view offset is 0 and buffer size not multiple of 16 | ||
@testset "Off == 0, buf % 16 != 0" begin | ||
test_view!(A, 1:51; shouldusemps=false) | ||
end | ||
|
||
# Test when view offset is 0 and buffer size is multiple of 16 | ||
@testset "Off == 0, buf % 16 == 0" begin | ||
test_view!(A, 1:32; shouldusemps=true) | ||
end | ||
|
||
# Test when view offset is not 0 nor multiple of 16 and buffer size not multiple of 16 | ||
@testset "Off != 0, buf % 16 != 0" begin | ||
test_view!(A, 3:51; shouldusemps=false) | ||
end | ||
|
||
# Test when view offset is multiple of 16 and buffer size not multiple of 16 | ||
@testset "Off % 16 == 0, buf % 16 != 0" begin | ||
test_view!(A, 17:51; shouldusemps=false) | ||
end | ||
|
||
# in-place | ||
for (f,T) in ((rand!,Float16), | ||
(rand!,Float32), | ||
(randn!,Float16), | ||
(randn!,Float32)), | ||
d in (2, (2,2), (2,2,2), 3, (3,3), (3,3,3)) | ||
A = MtlArray{T}(undef, d) | ||
fill!(A, T(0)) | ||
f(A) | ||
@test !iszero(collect(A)) | ||
end | ||
|
||
# out-of-place, with implicit type | ||
for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32)), | ||
args in ((2,), (2, 2), (3,), (3, 3)) | ||
A = f(args...) | ||
@test eltype(A) == T | ||
end | ||
|
||
# out-of-place, with type specified | ||
for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32), | ||
(rand,Float32), (randn,Float32)), | ||
args in ((T, 2), (T, 2, 2), (T, (2, 2)), (T, 3), (T, 3, 3), (T, (3, 3))) | ||
A = f(args...) | ||
@test eltype(A) == T | ||
end | ||
|
||
## seeding | ||
Metal.seed!(1) | ||
a = Metal.rand(Int32, 1) | ||
Metal.seed!(1) | ||
b = Metal.rand(Int32, 1) | ||
@test iszero(collect(a) - collect(b)) | ||
# Test when view offset is multiple of 16 and buffer size multiple of 16 | ||
@testset "Off % 16 == 0, buf % 16 == 0" begin | ||
test_view!(A, 17:32; shouldusemps=false) | ||
end | ||
end | ||
end | ||
# out-of-place, with implicit type | ||
@testset "out-of-place" begin | ||
@testset "$f with implicit type" for (f, T) in | ||
((Metal.rand, Float32), (Metal.randn, Float32)) | ||
@testset "args" for args in ((1,), (3,), (3, 3), (16,), (16, 16)) | ||
A = f(args...) | ||
@test eltype(A) == T | ||
end | ||
end | ||
|
||
# out-of-place, with type specified | ||
@testset "$f with $T" for (f, T) in OOPLACE_TUPLES | ||
@testset "$args" for args in ((T, 1), | ||
(T, 3), | ||
(T, 3, 3), | ||
(T, (3, 3)), | ||
(T, 16), | ||
(T, 16, 16), | ||
(T, (16, 16))) | ||
A = f(args...) | ||
@test eltype(A) == T | ||
end | ||
end | ||
end | ||
## seeding | ||
@testset "Seeding" begin | ||
Metal.seed!(1) | ||
a = Metal.rand(Int32, 1) | ||
Metal.seed!(1) | ||
b = Metal.rand(Int32, 1) | ||
@test iszero(collect(a) - collect(b)) | ||
end | ||
end # testset |