Skip to content

Commit

Permalink
Merge pull request #324 from JuliaArrays/fix-nonsquare-matrix-mult
Browse files Browse the repository at this point in the history
Fix sizes in non-square medium sized matrix multiply
  • Loading branch information
c42f authored Oct 20, 2017
2 parents 8aacd2b + cae40b8 commit 11f371d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,11 @@ end
S = Size(sa[1], sb[2])

# Do a custom b[:, k2] to return a SVector (an isbits type) rather than (possibly) a mutable type. Avoids allocation == faster
tmp_type_in = :(SVector{$(sa[1]), T})
tmp_type_out = :(SVector{$(sb[1]), T})
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sa[1])), a, $(Expr(:call, tmp_type_in, [Expr(:ref, :b, sub2ind(S, i, k2)) for i = 1:sb[1]]...)))::$tmp_type_out) for k2 = 1:sb[2]]
tmp_type_in = :(SVector{$(sb[1]), T})
tmp_type_out = :(SVector{$(sa[1]), T})
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sb[1])), a,
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, sub2ind(sb, i, k2)) for i = 1:sb[1]]...)))::$tmp_type_out)
for k2 = 1:sb[2]]

exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]

Expand Down
8 changes: 8 additions & 0 deletions test/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@
n2 = SMatrix{16,16}(n_array2)
@test m2*n2 === SMatrix{16,16}(a_array2)

# Non-square version
m_array3 = rand(1:10, 9, 10)
n_array3 = rand(1:10, 10, 11)
a_array3 = m_array3*n_array3
m3 = SMatrix{9,10}(m_array3)
n3 = SMatrix{10,11}(n_array3)
@test m3*n3 === SMatrix{9,11}(a_array3)

# Mutating types follow different behaviour
m_array = rand(1:10, 10, 10)
n_array = rand(1:10, 10, 10)
Expand Down

0 comments on commit 11f371d

Please sign in to comment.