Skip to content

Commit

Permalink
[SparseArrays] Respect order of mul in (l)mul!(::Diagonal,::Sparse) (#…
Browse files Browse the repository at this point in the history
…30163)

* order of mul in (l)mul!(::Diagonal,::Sparse)

add tests for non-commutative mul

* Create Quaternions.jl

remove quaternions from generic tests
  • Loading branch information
dkarrasch authored and andreasnoack committed Dec 13, 2018
1 parent 49023b5 commit 3ee8798
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 30 deletions.
31 changes: 4 additions & 27 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,10 @@
module TestGeneric

using Test, LinearAlgebra, Random
import Base: -, *, /, \

# A custom Quaternion type with minimal defined interface and methods.
# Used to test mul and mul! methods to show non-commutativity.
struct Quaternion{T<:Real} <: Number
s::T
v1::T
v2::T
v3::T
end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.abs(q::Quaternion) = sqrt(abs2(q))
Base.real(::Type{Quaternion{T}}) where {T} = T
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)

(-)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3)
(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3,
q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2,
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s)
(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r)
(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity
(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w))
(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

Random.seed!(123)

Expand Down
6 changes: 3 additions & 3 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ function mul!(C::SparseMatrixCSC, D::Diagonal{T, <:Vector}, A::SparseMatrixCSC)
Arowval = A.rowval
resize!(Cnzval, length(Anzval))
for col = 1:n, p = A.colptr[col]:(A.colptr[col+1]-1)
@inbounds Cnzval[p] = Anzval[p] * b[Arowval[p]]
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
end
C
end
Expand Down Expand Up @@ -1238,7 +1238,7 @@ function rmul!(A::SparseMatrixCSC, D::Diagonal)
(n == size(D, 1)) || throw(DimensionMismatch())
Anzval = A.nzval
@inbounds for col = 1:n, p = A.colptr[col]:(A.colptr[col + 1] - 1)
Anzval[p] *= D.diag[col]
Anzval[p] = Anzval[p] * D.diag[col]
end
return A
end
Expand All @@ -1249,7 +1249,7 @@ function lmul!(D::Diagonal, A::SparseMatrixCSC)
Anzval = A.nzval
Arowval = A.rowval
@inbounds for col = 1:n, p = A.colptr[col]:(A.colptr[col + 1] - 1)
Anzval[p] *= D.diag[Arowval[p]]
Anzval[p] = D.diag[Arowval[p]] * Anzval[p]
end
return A
end
Expand Down
24 changes: 24 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ end
@test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2))
end

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

sA = sprandn(3, 7, 0.5)
sC = similar(sA)
dA = Array(sA)
Expand Down Expand Up @@ -403,6 +407,26 @@ dA = Array(sA)
@test_throws DimensionMismatch rdiv!(copy(sAt), Diagonal(fill(1., length(b)+1)))
@test_throws LinearAlgebra.SingularException rdiv!(copy(sAt), Diagonal(zeros(length(b))))
end

@testset "non-commutative multiplication" begin
# non-commutative multiplication
Avals = Quaternion.(randn(10), randn(10), randn(10), randn(10))
sA = sparse(rand(1:3, 10), rand(1:7, 10), Avals, 3, 7)
sC = copy(sA)
dA = Array(sA)

b = Quaternion.(randn(7), randn(7), randn(7), randn(7))
D = Diagonal(b)
@test Array(sA * D) dA * D
@test rmul!(copy(sA), D) dA * D
@test mul!(sC, copy(sA), D) dA * D

b = Quaternion.(randn(3), randn(3), randn(3), randn(3))
D = Diagonal(b)
@test Array(D * sA) D * dA
@test lmul!(D, copy(sA)) D * dA
@test mul!(sC, D, copy(sA)) D * dA
end
end

@testset "copyto!" begin
Expand Down
34 changes: 34 additions & 0 deletions test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module Quaternions

export Quaternion

# A custom Quaternion type with minimal defined interface and methods.
# Used to test mul and mul! methods to show non-commutativity.
struct Quaternion{T<:Real} <: Number
s::T
v1::T
v2::T
v3::T
end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.abs(q::Quaternion) = sqrt(abs2(q))
Base.real(::Type{Quaternion{T}}) where {T} = T
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))

Base.:(+)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3)
Base.:(-)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3)
Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3,
q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2,
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s)
Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r)
Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity
Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w))
Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))

end

0 comments on commit 3ee8798

Please sign in to comment.