Gpu #3

Open · wants to merge 19 commits into base: main
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
**/*.mem
**/*.swp
**/*.zip
statprof
190 changes: 113 additions & 77 deletions src/MPIQR.jl
@@ -2,6 +2,7 @@ module MPIQR

using LinearAlgebra, Base.Threads, Base.Iterators, Combinatorics
using MPI, MPIClusterManagers, ProgressMeter
using LinearAlgebra.BLAS

alphafactor(x::Real) = -sign(x)
alphafactor(x::Complex) = -exp(im * angle(x))
@@ -19,12 +20,17 @@ struct MPIQRMatrix{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T}
end

function validblocksizes(numcols::Integer, commsize::Integer)::Vector{Int}
iszero(n ÷ c) || return [0]
iszero(numcols ÷ commsize) || return [0]
return findall(iszero(numcols % i) for i in 1:(numcols ÷ commsize))
end

function localcolumns(rnk, n, blocksize, commsize)
return vcat(collect(partition(collect(1:n), blocksize))[rnk + 1:commsize:end]...)
output = vcat(collect(partition(collect(1:n), blocksize))[rnk + 1:commsize:end]...)
@assert length(output) > 0 (rnk, n, blocksize, commsize)
@assert minimum(output) >= 1
@assert maximum(output) <= n
@assert issorted(output)
return output
end
localcolumns(A::MPIQRMatrix) = A.localcolumns
localmatrix(A::MPIQRMatrix) = A.localmatrix
@@ -37,33 +43,44 @@ function MPIQRMatrix(localmatrix::AbstractMatrix, globalsize; blocksize=1, comm
m, n = globalsize
@assert mod(n, blocksize) == 0
localcols = localcolumns(rnk, n, blocksize, commsize)
@assert length(localcols) > 0
@assert minimum(localcols) >= 1
@assert maximum(localcols) <= n
@assert issorted(localcols)
colsets = Vector{Set{Int}}()
for r in 0:commsize-1
push!(colsets, Set(localcolumns(r, n, blocksize, commsize)))
end
@assert size(localmatrix, 2) == length(localcols)
if size(localmatrix, 2) != length(localcols)
throw(ArgumentError(
"This rank's matrix must have the right number of local columns"))
end

lookupop(j) = (x = searchsortedfirst(localcols, j); isnothing(x) ? 0 : x)
columnlookup = Vector{Int}([lookupop(j) for j in 1:n])
@assert minimum(columnlookup) >= 0
@assert maximum(columnlookup) <= n
return MPIQRMatrix(localmatrix, globalsize, localcols, columnlookup, colsets, blocksize, rnk, comm, commsize)
end
columnowner(A::MPIQRMatrix, j) = findfirst(in(j, s) for s in A.colsets) - 1

function columnowner(A::MPIQRMatrix, j)::Int
for (i, cols) in enumerate(A.colsets)
in(j, cols) && return i - 1
end
@assert false "Shouldn't be able to get here"
return -1
end

Base.size(A::MPIQRMatrix) = A.globalsize
Base.getindex(A::MPIQRMatrix, i, j) = A.localmatrix[i, localcolindex(A, j)]

function Base.setindex!(A::MPIQRMatrix, v::Number, i, j)
return A.localmatrix[i, localcolindex(A, j)] = v
end
function Base.:*(A::MPIQRMatrix{T}, x::AbstractVector{U}) where {T,U}
y = A.localmatrix * x[A.localcolumns]
return MPI.Allreduce(y, +, A.comm)

# define these for dispatch purposes
Base.:*(A::MPIQRMatrix{T,M}, x::AbstractVector) where {T,M} = _mul(A, x)
Base.:*(A::MPIQRMatrix{T,M}, x::AbstractMatrix) where {T,M} = _mul(A, x)
function _mul(A::MPIQR.MPIQRMatrix, x)
y = A.localmatrix * x[A.localcolumns, :] # can't use views for gpus
MPI.Allreduce!(y, +, A.comm)
return y
end
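# Hedged usage sketch (names and sizes are illustrative, not part of the API):
# every rank passes the full right-hand side, multiplies it against its local
# columns only, and the Allreduce leaves the identical length-m product on all
# ranks, e.g.
#   y = A * x   # A::MPIQRMatrix distributed over A.comm, x a plain Vector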

localsize(A::MPIQRMatrix, dim=nothing) = size(A.localmatrix, dim)
@@ -95,6 +112,7 @@ end
Base.first(cii::ColumnIntersectionIterator) = cii.localcolumns[first(cii.indices)]
Base.last(cii::ColumnIntersectionIterator) = cii.localcolumns[last(cii.indices)]
Base.length(cii::ColumnIntersectionIterator) = length(cii.indices)
Base.view(cii::ColumnIntersectionIterator) = view(cii.localcolumns, cii.indices)

function Base.intersect(A::MPIQRMatrix, cols)
indexa = searchsortedfirst(A.localcolumns, first(cols))
@@ -106,8 +124,11 @@ function Base.intersect(A::MPIQRMatrix, cols)
return output
end

const IsBitsUnion = Union{Float32, Float64, ComplexF32, ComplexF64,
Vector{Float32}, Vector{Float64}, Vector{ComplexF32}, Vector{ComplexF64}}
function Base.view(H::MPIQRMatrix, is, js)
lja = localcolindex(H, first(js))
ljz = localcolindex(H, last(js))
return view(H.localmatrix, is, lja:ljz)
end

function hotloopviews(H::MPIQRMatrix, Hj::AbstractMatrix, Hr, y, j, ja, jz, m, n,
js = intersect(H, ja:jz))
@@ -148,17 +169,18 @@ columns of Hj. This function calculates the combinations of these dot products.
```julia
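# Hedged example worked through by hand from the implementation below;
# the index paths from A = 0 up to N = 3 come back longest first:
unrecursedcoeffs(3, 0) == [[0, 1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 3]]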
```
"""
function unrecursedcoeffs(N, A)
A >= N && return Any[(N, N)]
output = Any[(A, N)]
function unrecursedcoeffs(N::Int, A::Int)
A >= N && return [[N, N]]
output = [[A, N]]
for i in 1:N-1, c in combinations(A+1:N-1, i)
push!(output, (A, c..., N))
push!(output, [A, c..., N])
end
return reverse(output)
reverse!(output)
return output
end

"""
recurse!(H::AbstractMatrix,Hj::AbstractArray{T},Hr,y) where {T<:IsBitsUnion}
recurse!(H::AbstractMatrix,Hj::AbstractArray{T},Hr,y) where {T}

Instead of applying the columns of `Hj` to `H` sequentially, it is better to
calculate the effective recursive action of `Hj` on `H` and store that in `Hr`
@@ -175,82 +197,94 @@ such that `Hr` can be applied to `H` in one big gemm call.
```julia
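# Hedged usage sketch; the sizes are purely illustrative.
# Hj holds a block of reflector columns, Hr receives their combined recursive
# action, and y is the workspace that ends up holding H' * Hj.
using LinearAlgebra: mul!
m, bs = 8, 2
H = randn(m, 4); Hj = randn(m, bs); Hr = similar(Hj); y = zeros(4, bs)
recurse!(H, Hj, Hr, y)
mul!(H, Hr, y', -1, true)  # H .-= Hr * y' applies the whole block in one gemm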
```
"""
function recurse!(H::AbstractMatrix, Hj::AbstractArray{T}, Hr, y) where {T<:IsBitsUnion}
dots = zeros(T, size(Hj, 2), size(Hj, 2)) # faster than a dict
@views @inbounds for i in 1:size(Hj, 2), j in 1:i
dots[i, j] = dot(Hj[:, i], Hj[:, j])
end
function recurse!(H, Hj, Hr, y)
T = promote_type(eltype(H), eltype(Hj))
dots = similar(H, size(Hj, 2), size(Hj, 2)) # faster than a dict
mul!(dots, Hj', Hj, true, false)
dots = Matrix(dots) # a no-op in CPU code

BLAS.gemm!('C', 'N', true, H, Hj, false, y)
# gemm!('C', 'N', true, H, Hj, false, y)
mul!(y, H', Hj, true, false)

copyto!(Hr, Hj)

# This recursion is intricate, but the tests pass. It is easiest to verify
# by running the same logic with symbolic quantities and inspecting the output.
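# A hedged worked case for a two-column block (bs == 2), read off from these loops:
#   Hr[:, 1] = Hj[:, 1] - dot(Hj[:, 2], Hj[:, 1]) * Hj[:, 2]
#   Hr[:, 2] = Hj[:, 2]
# so the single mul! in hotloop! matches applying the two reflectors one by one.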
@views @inbounds for ii in 0:size(Hj, 2) - 1
for i in ii + 1:size(Hj, 2) - 1
for urc in unrecursedcoeffs(i, ii)
factor = one(T)
@views @inbounds for ii in 1:size(Hj, 2)
for i in ii + 1:size(Hj, 2)
summand = zero(T)
for urc in unrecursedcoeffs(i - 1, ii - 1)
factor = -one(T) * (-1)^length(urc)
@inbounds for j in 2:length(urc)
factor *= dots[urc[j] + 1, urc[j-1] + 1]
end
BLAS.axpy!(-(-1)^length(urc) * factor, view(Hj, :, i + 1), view(Hr, :, ii + 1))
summand += factor
end
axpy!(summand, view(Hj, :, i), view(Hr, :, ii))
end
end
end

function hotloop!(H::AbstractMatrix, Hj::AbstractArray{T}, Hr, y) where {T<:IsBitsUnion}
function hotloop!(H::AbstractMatrix, Hj::AbstractArray{T}, Hr, y) where {T}

recurse!(H, Hj, Hr, y)

BLAS.gemm!('N', 'C', -one(T), Hr, y, true, H) # H .-= Hj * y'
# BLAS.gemm!('N', 'C', -one(T), Hr, y, true, H) # H .-= Hj * y'
mul!(H, Hr, y', -1, true)

return nothing
end


function householder!(H::MPIQRMatrix{T}, α=zeros(T, size(H, 2)); verbose=false,
progress=FakeProgress()) where T
function norm_bcast(x::AbstractArray)
return sqrt(reduce(+, Base.Broadcast.Broadcasted(
z->real(z) * real(z) + imag(z) * imag(z), (x,));
init=zero(real(eltype(x))) ))
end

function householder!(H::MPIQRMatrix{T,M}, α=similar(H.localmatrix, size(H, 2));
verbose=false, progress=FakeProgress()) where {T,M}
m, n = size(H)
@assert m >= n
bs = blocksize(H) # the blocksize / tilesize of contiguous columns on each rank
Hj = zeros(T, m, bs) # the H column(s)
Hr = zeros(T, m, bs) # the H column(s)
Hjcopy = bs > 1 ? zeros(T, m) : Hj # copy of the H column(s)
Hj = similar(H.localmatrix, m, bs) # the H column(s)
Hr = similar(H.localmatrix, m, bs) # the recursively combined reflector column(s)
t1 = t2 = t3 = t4 = t5 = 0.0
# work array for the BLAS call
y = zeros(eltype(H), localcolsize(H, 1:n), bs)
y = similar(H.localmatrix, localcolsize(H, 1:n), bs)

# send the first column(s) of H to Hj on all ranks
j = 1
src = columnowner(H, j)
if H.rank == src
@views copyto!(Hj[j:m, :], H[j:m, j:j - 1 + bs])
copyto!(view(Hj, j:m, :), view(H, j:m, j:j - 1 + bs))
end
MPI.Bcast!(view(Hj, j:m, :), H.comm; root=src)

tmp = zeros(T, m * bs)
@inbounds @views for j in 1:bs:n
tmp = similar(H.localmatrix, m * bs)
@inbounds for j in 1:bs:n
colowner = columnowner(H, j)

# process all the first bs column(s) of H
@inbounds for Δj in 0:bs-1
t1 += @elapsed @views begin
s = norm(Hj[j + Δj:m, 1 + Δj])
α[j + Δj] = s * alphafactor(Hj[j + Δj, 1 + Δj])
f = 1 / sqrt(s * (s + abs(Hj[j + Δj, 1 + Δj])))
Hj[j:j + Δj - 1, 1 + Δj] .= 0
Hj[j + Δj, 1 + Δj] -= α[j + Δj]
Hj[j + Δj:m, 1 + Δj] .*= f
t1 += @elapsed begin
jj = j + Δj
#s = norm(view(Hj, jj:m, 1 + Δj))
viewHj = view(Hj, jj:m, 1 + Δj)
#s = norm_bcast(viewHj)
s = sqrt(real(dot(viewHj, viewHj)))
view(α, jj) .= s .* alphafactor.(view(Hj, jj, 1 + Δj))
f = 1 / sqrt(s * (s + abs(sum(view(Hj, jj, 1 + Δj)))))
view(Hj, j:jj - 1, 1 + Δj) .= 0
view(Hj, jj, 1 + Δj) .-= view(α, jj)
view(Hj, jj:m, 1 + Δj) .*= f
end

t2 += @elapsed bs > 1 && copyto!(view(Hjcopy, j+Δj:m, 1), view(Hj, j+Δj:m, 1 + Δj)) # prevent data race
t3 += @elapsed hotloop!(view(Hj, j+Δj:m, 1 + Δj:bs), view(Hjcopy, j+Δj:m, 1), view(Hr, j+Δj:m, 1), view(y, 1 + Δj:bs))
t3 += @elapsed hotloop!(view(Hj, j+Δj:m, 1 + Δj:bs), view(Hj, j+Δj:m, 1+Δj), view(Hr, j+Δj:m, 1), view(y, 1 + Δj:bs))

t2 += @elapsed if H.rank == colowner
@views copyto!(H[j + Δj:m, j + Δj:j-1+bs], Hj[j + Δj:m, 1 + Δj:bs])
copyto!(view(H, j + Δj:m, j + Δj:j-1+bs), view(Hj, j + Δj:m, 1 + Δj:bs))
end
end

@@ -263,10 +297,7 @@ function householder!(H::MPIQRMatrix{T}, α=zeros(T, size(H, 2)); verbose=false,
src = columnowner(H, j + bs)
reqs = Vector{MPI.Request}()
if H.rank == src
k = 0
for (cj, jj) in enumerate(j + bs:j - 1 + 2bs), (ci, ii) in enumerate(j+bs:m)
@inbounds tmp[k+=1] = H[ii, jj]
end
@inbounds tmp .= reshape(view(H, j+bs:m, j+bs:j-1+2bs), (m-j-bs+1) * bs)
for r in filter(!=(src), 0:H.commsize-1)
push!(reqs, MPI.Isend(tmp, H.comm; dest=r, tag=j + bs))
end
@@ -288,59 +319,64 @@ function householder!(H::MPIQRMatrix{T}, α=zeros(T, size(H, 2)); verbose=false,
next!(progress)
end
ts = (t1, t2, t3, t4, t5)
verbose && H.rank == 0 && @show (ts ./ sum(ts)..., sum(ts))
verbose && @show ts
return MPIQRStruct(H, α)
end


function solve_householder!(b, H, α; progress=FakeProgress(), verbose=false)
m, n = size(H)
bs = blocksize(H)
# multiply by Q' ...
b1 = zeros(eltype(b), length(b))
b2 = zeros(eltype(b), length(b))
b1 = similar(b, size(b))
s = similar(b, (1, size(b, 2)))
ta = tb = tc = td = te = 0.0
@inbounds @views for j in 1:bs:n
b1[j:m] .= 0
@inbounds for j in 1:bs:n
tb += @elapsed b1[j:m, :] .= 0
blockrank = columnowner(H, j)
if H.rank == blockrank
for jj in 0:bs-1
@assert columnowner(H, j) == blockrank
ta += @elapsed s = dot(H[j+jj:m, j+jj], b[j+jj:m])
tb += @elapsed b2[j+jj:m] .= H[j+jj:m, j+jj] .* s
tb += @elapsed b[j+jj:m] .-= b2[j+jj:m]
tb += @elapsed b1[j+jj:m] .+= b2[j+jj:m]
ta += @elapsed s .= view(H, j+jj:m, j+jj)' * view(b, j+jj:m, :)
tb += @elapsed view(b, j+jj:m, :) .-= view(H, j+jj:m, j+jj) * s
tb += @elapsed view(b1, j+jj:m, :) .+= view(H, j+jj:m, j+jj) * s
end
end
#tc += @elapsed MPI.Allreduce!(view(b1, j:m, :), +, H.comm)
tc += @elapsed MPI.Allreduce!(b1, +, H.comm)
if H.rank != blockrank
b[j:m] .-= b1[j:m]
tb += @elapsed b[j:m, :] .-= view(b1, j:m, :)
end
b1[j:j+bs-1] .= 0
end
# now that b holds the value of Q'b
# we may back-substitute with R
@inbounds @views for i in n:-1:1
bi = zero(eltype(b))
td += @elapsed @inbounds for j in intersect(H, i+1:n)
bi += H[i, j] * b[j]
td += @elapsed view(b, n, :) ./= view(α, n) # handle row n here; the loop below starts at n - 1
bi = similar(b, (1, size(b, 2)))
@inbounds for i in n-1:-1:1 # can't assume @views with a CuArray
bi .= 0
td += @elapsed jiter = intersect(H, i+1:n)
td += @elapsed @inbounds if !isempty(jiter)
jview = view(jiter)
# can't do transpose(view)
bi .+= sum(H[i, jview] .* b[jview, :], dims=1)
end
te += @elapsed bi = MPI.Allreduce(bi, +, H.comm)
b[i] -= bi
b[i] /= α[i]
te += @elapsed MPI.Allreduce!(bi, +, H.comm)
td += @elapsed view(b, i, :) .= (view(b, i, :) .- transpose(bi)) ./ view(α, i)
next!(progress)
end
ts = (ta, tb, tc, td, te)
verbose && H.rank == 0 && @show (ts ./ sum(ts)..., sum(ts))
return b[1:n]
verbose && @show ts
return b[1:n, :]
end

struct MPIQRStruct{T1, T2}
A::T1
α::T2
end

MPIQRStruct(A::MPIQRMatrix) = MPIQRStruct(A, zeros(eltype(A), size(A, 2)))
MPIQRStruct(A::MPIQRMatrix{T,M}) where {T,M} = MPIQRStruct(A, similar(A.localmatrix, size(A, 2)))

function LinearAlgebra.qr!(A::MPIQRMatrix; progress=FakeProgress(), verbose=false)
H = MPIQRStruct(A)