Skip to content

Commit

Permalink
Improved output type inference for matrix products
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy Ferris committed Oct 25, 2016
1 parent 9989c89 commit a310bf5
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ import Base: A_mul_B!, Ac_mul_B!, A_mul_Bc!, Ac_mul_Bc!, At_mul_B!, A_mul_Bt!, A

typealias BlasEltypes Union{Float64, Float32, Complex{Float64}, Complex{Float32}}

# Stolen from https://github.com/JuliaLang/julia/pull/18218
matprod(x,y) = x*y + x*y;

# TODO size-inferrable products with AbstractArray (such as StaticMatrix * AbstractVector)
# TODO Potentially a loop version for rather large arrays? Or try and figure out inference problems?


# TODO make faster versions of A*_mul_B*
@generated function A_mul_Bc(A::Union{StaticMatrix, StaticVector}, B::Union{StaticMatrix, StaticVector})
return quote
Expand Down Expand Up @@ -94,7 +96,7 @@ end
sb = size(b)

s = (sA[1],)
T = typeof(zero(TA)*zero(Tb))
T = promote_op(matprod, TA, Tb)

if sb[1] != sA[2]
error("Dimension mismatch")
Expand Down Expand Up @@ -133,7 +135,7 @@ end
sB = size(B)

s = (sa[1],sB[2])
T = typeof(zero(Ta)*zero(TB))
T = promote_op(matprod, Ta, TB)

if sB[1] != 1
error("Dimension mismatch")
Expand Down Expand Up @@ -167,7 +169,7 @@ end
TA = eltype(A)
TB = eltype(B)

T = typeof(zero(TA)*zero(TB))
T = promote_op(matprod, TA, TB)

can_mutate = !isbits(A) || !isbits(B) # !isbits implies can get a persistent pointer (to pass to BLAS). Probably will change to !isimmutable in a future version of Julia.
can_blas = T == TA && T == TB && T <: Union{Float64, Float32, Complex{Float64}, Complex{Float32}}
Expand Down Expand Up @@ -229,7 +231,7 @@ end
TB = eltype(B)

s = (sA[1], sB[2])
T = typeof(zero(TA)*zero(TB))
T = promote_op(matprod, TA, TB)

if sB[1] != sA[2]
error("Dimension mismatch")
Expand Down Expand Up @@ -270,7 +272,7 @@ end
TB = eltype(B)

s = (sA[1], sB[2])
T = typeof(zero(TA)*zero(TB))
T = promote_op(matprod, TA, TB)

if sB[1] != sA[2]
error("Dimension mismatch")
Expand Down Expand Up @@ -315,7 +317,7 @@ end
TB = eltype(B)

s = (sA[1], sB[2])
T = typeof(zero(TA)*zero(TB))
T = promote_op(matprod, TA, TB)

if sB[1] != sA[2]
error("Dimension mismatch")
Expand Down Expand Up @@ -359,7 +361,7 @@ end
sb = size(b)

s = (sA[1],)
T = typeof(zero(TA)*zero(Tb))
T = promote_op(matprod, TA, Tb)

if sb[1] != sA[2]
error("Dimension mismatch")
Expand Down Expand Up @@ -417,7 +419,7 @@ end

TA = eltype(A)
TB = eltype(B)
T = typeof(zero(TA)*zero(TB))
T = promote_op(matprod, TA, TB)

can_blas = T == TA && T == TB && T <: Union{Float64, Float32, Complex{Float64}, Complex{Float32}}

Expand Down

0 comments on commit a310bf5

Please sign in to comment.