Skip to content

Commit

Permalink
Merge pull request #26 from invenia/npr/perf
Browse files Browse the repository at this point in the history
Matmul performance improvements
  • Loading branch information
nickrobinson251 authored Oct 22, 2019
2 parents 3bd33a3 + 829034c commit 85405f4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/BlockDiagonals.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module BlockDiagonals

using Base: @propagate_inbounds
using FillArrays: Zeros
using LinearAlgebra

Expand Down
38 changes: 22 additions & 16 deletions src/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,39 +83,45 @@ function _check_matmul_dims(A::AbstractMatrix, B::AbstractMatrix)
))
end

function Base.:*(B::BlockDiagonal, M::AbstractMatrix)
function Base.:*(B::BlockDiagonal{T}, M::AbstractMatrix) where T
_check_matmul_dims(B, M)
bblocks = blocks(B)
new_blocksizes = zip(size.(bblocks, 1), fill(size(M, 2), length(bblocks)))
d = similar.(bblocks, T, new_blocksizes)
ed = 0
d = map(blocks(B)) do block
@inbounds @views for (p, block) in enumerate(bblocks)
st = ed + 1 # start
ed += size(block, 2) # end
return block * M[st:ed, :]
mul!(d[p], block, M[st:ed, :])
end
return reduce(vcat, d)::Matrix
return reduce(vcat, d)
end

function Base.:*(M::AbstractMatrix, B::BlockDiagonal)
function Base.:*(M::AbstractMatrix, B::BlockDiagonal{T}) where T
_check_matmul_dims(M, B)
bblocks = blocks(B)
new_blocksizes = zip(fill(size(M, 1), length(bblocks)), size.(bblocks, 2))
d = similar.(bblocks, T, new_blocksizes)
ed = 0
d = map(blocks(B)) do block
@inbounds @views for (p, block) in enumerate(bblocks)
st = ed + 1 # start
ed += size(block, 1) # end
return M[:, st:ed] * block
mul!(d[p], M[:, st:ed], block)
end
return reduce(hcat, d)::Matrix
return reduce(hcat, d)
end

# Diagonal
function Base.:*(B::BlockDiagonal, M::Diagonal)::BlockDiagonal
_check_matmul_dims(B, M)
A = copy(B)
d = diag(M)
A = similar(B)
d = parent(M)
col = 1
for (p, block) in enumerate(blocks(B))
@inbounds @views for (p, block) in enumerate(blocks(B))
ncols = size(block, 2)
cols = col:(col + ncols-1)
for (j, c) in enumerate(cols)
getblock(A, p)[:, j] *= d[c]
mul!(getblock(A, p)[:, j], block[:, j], d[c])
end
col += ncols
end
Expand All @@ -124,14 +130,14 @@ end

function Base.:*(M::Diagonal, B::BlockDiagonal)::BlockDiagonal
_check_matmul_dims(M, B)
A = copy(B)
d = diag(M)
A = similar(B)
d = parent(M)
row = 1
for (p, block) in enumerate(blocks(B))
@inbounds @views for (p, block) in enumerate(blocks(B))
nrows = size(block, 1)
rows = row:(row + nrows-1)
for (i, r) in enumerate(rows)
getblock(A, p)[i, :] *= d[r]
mul!(getblock(A, p)[i, :], block[i, :], d[r])
end
row += nrows
end
Expand Down
6 changes: 3 additions & 3 deletions src/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ Base.size(B::BlockDiagonal) = sum(first∘size, blocks(B)), sum(last∘size, blo
Base.similar(B::BlockDiagonal) = BlockDiagonal(map(similar, blocks(B)))
Base.parent(B::BlockDiagonal) = B.blocks

function Base.setindex!(B::BlockDiagonal{T}, v, i::Integer, j::Integer) where T
@propagate_inbounds function Base.setindex!(B::BlockDiagonal, v, i::Integer, j::Integer)
p, i_, j_ = _block_indices(B, i, j)
if p > 0
@inbounds getblock(B, p)[i_, end+j_] = v
Expand All @@ -118,10 +118,10 @@ function Base.setindex!(B::BlockDiagonal{T}, v, i::Integer, j::Integer) where T
return v
end

function Base.getindex(B::BlockDiagonal{T}, i::Integer, j::Integer) where T
@propagate_inbounds function Base.getindex(B::BlockDiagonal{T}, i::Integer, j::Integer) where T
p, i, j = _block_indices(B, i, j)
# if not in on-diagonal block `p` then value at `i, j` must be zero
return p > 0 ? getblock(B, p)[i, end + j] : zero(T)
@inbounds return p > 0 ? getblock(B, p)[i, end + j] : zero(T)
end

# Transform indices `i, j` (identifying entry `Matrix(B)[i, j]`) into indices `p, i, j` such
Expand Down
8 changes: 8 additions & 0 deletions test/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ using Test
end # Addition

@testset "Multiplication" begin

@testset "BlockDiagonal * BlockDiagonal" begin
@test b1 * b1 isa BlockDiagonal
@test Matrix(b1 * b1) Matrix(b1) * Matrix(b1)
Expand Down Expand Up @@ -96,6 +97,13 @@ using Test
@test A′ * b1 isa Matrix
@test A′ * b1 A′ * Matrix(b1)
@test_throws DimensionMismatch A * b1

# degenerate cases
m = rand(0, 0)
@test m * BlockDiagonal([m]) == m * m == m
m = rand(5, 0)
@test m' * BlockDiagonal([m]) == m' * m == rand(0, 0)
@test m * BlockDiagonal([m']) == m * m' == zeros(5, 5)
end

@testset "BlockDiagonal * Diagonal" begin
Expand Down

2 comments on commit 85405f4

@nickrobinson251
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/4617

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.3 -m "<description of version>" 85405f4fcc54a5a3a5aa5063392eaabe8c96bba6
git push origin v0.1.3

Please sign in to comment.