
add generalized dot product #32739

Merged: 35 commits, Sep 5, 2019

Commits
3f69bd0
add generalized dot product
dkarrasch Jul 30, 2019
271781b
add generalized dot for Adjoint and Transpose
dkarrasch Jul 31, 2019
819c601
add "generalized" dot for UniformScalings
dkarrasch Jul 31, 2019
4ea06de
fix adjoint/transpose in tridiags
dkarrasch Jul 31, 2019
512c122
improve generic dot, add tests
dkarrasch Jul 31, 2019
b068ed1
fix typos, optimize *diag, require_one_based_indexing
dkarrasch Aug 1, 2019
96b3a8b
add tests
dkarrasch Aug 1, 2019
6012d59
fix typos in triangular and tridiag
dkarrasch Aug 1, 2019
430320b
fix BigFloat tests in triangular
dkarrasch Aug 1, 2019
5008546
add sparse tests (and minor fix)
dkarrasch Aug 1, 2019
a28280f
handle block arrays of varying lengths
dkarrasch Aug 6, 2019
93bb8c7
make generalized dot act recursively
dkarrasch Aug 6, 2019
6aea6bb
add generalized dot for symmetric/Hermitian matrices
dkarrasch Aug 6, 2019
4855aa5
fix triangular case
dkarrasch Aug 7, 2019
427b849
more complete tests for Symmetric/Hermitian
dkarrasch Aug 7, 2019
93b59be
fix UnitLowerTriangular case
dkarrasch Aug 7, 2019
791be8b
fix complex case in symmetric gendot
dkarrasch Aug 7, 2019
5d6bbbb
interpret dot(x, A, y) as dot(A'x, y), test accordingly
dkarrasch Aug 7, 2019
c176bfc
use correct tolerance in triangular tests
dkarrasch Aug 7, 2019
9510dfd
add gendot for UpperHessenberg, and tests
dkarrasch Aug 8, 2019
a6bbb45
fix docstring of 3-arg dot
dkarrasch Aug 16, 2019
e4668ec
add generic 3-arg dot for UniformScaling
dkarrasch Aug 16, 2019
0932238
add generic fallback
dkarrasch Aug 16, 2019
6310fa6
add gendot with middle argument Number
dkarrasch Aug 16, 2019
a66a6a1
merge NEWS
dkarrasch Aug 28, 2019
fb97cc3
attach docstring to generic fallback
dkarrasch Aug 18, 2019
c9112fc
simplify scalar/uniform scaling gendot
dkarrasch Aug 18, 2019
cffa4aa
use dot(A'x,y) for fallback
dkarrasch Aug 29, 2019
48bdbc1
use accessor functions in sparse code, generalize to Abstract..., tests
dkarrasch Aug 30, 2019
ff377e6
revert fallback definition
dkarrasch Aug 30, 2019
48874e2
add compat note and jldoctest
dkarrasch Aug 31, 2019
5568795
remove redundant Number version
dkarrasch Aug 30, 2019
8051614
write out loops in symmetric/hermitian case
dkarrasch Aug 30, 2019
200732c
test quaternions in uniformscaling gendot
dkarrasch Aug 30, 2019
d304264
fix uniformscaling test
dkarrasch Aug 30, 2019
1 change: 1 addition & 0 deletions NEWS.md
@@ -39,6 +39,7 @@ Standard library changes

* `qr` and `qr!` functions support `blocksize` keyword argument ([#33053]).

* `dot` now admits a 3-argument method `dot(x, A, y)` to compute generalized dot products `dot(x, A*y)`, but without computing and storing the intermediate result `A*y` ([#32739]).

#### SparseArrays

30 changes: 30 additions & 0 deletions stdlib/LinearAlgebra/src/bidiag.jl
@@ -647,6 +647,36 @@ function *(A::SymTridiagonal, B::Diagonal)
A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B)
end

function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(B, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
end
ev, dv = B.ev, B.dv
if B.uplo == 'U'
x₀ = x[1]
r = dot(x[1], dv[1], y[1])
@inbounds for j in 2:nx-1
x₋, x₀ = x₀, x[j]
r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j])
end
r += dot(adjoint(ev[nx-1])*x₀ + adjoint(dv[nx])*x[nx], y[nx])
return r
else # B.uplo == 'L'
x₀ = x[1]
x₊ = x[2]
r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1])
@inbounds for j in 2:nx-1
x₀, x₊ = x₊, x[j+1]
r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j])
end
r += dot(x₊, dv[nx], y[nx])
return r
end
end

#Linear solvers
ldiv!(A::Union{Bidiagonal, AbstractTriangular}, b::AbstractVector) = naivesub!(A, b)
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVector) = ldiv!(copy(A), b)
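
Not part of the diff: a minimal sketch (with made-up data) of how the new `Bidiagonal` method can be checked against the dense equivalent, mirroring the pattern used in test/bidiag.jl:

```julia
using LinearAlgebra

dv = [1.0, 2.0, 3.0]   # diagonal
ev = [0.5, 0.25]       # off-diagonal
x  = [1.0, -1.0, 2.0]
y  = [2.0, 0.0, 1.0]

for uplo in (:U, :L)
    B = Bidiagonal(dv, ev, uplo)
    # same value as the dense computation, without materializing B*y
    @assert dot(x, B, y) ≈ dot(x, Matrix(B) * y)
end
```
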
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
@@ -637,11 +637,14 @@ end

# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
function dot(x::AbstractVector, D::Diagonal, y::AbstractVector)
mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y))
end

function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
info = 0
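
Not part of the diff: for `Diagonal`, the method above reduces to a single sum of `conj(x[i]) * D[i,i] * y[i]`. A sketch with made-up data (`D.diag` is the stored diagonal, as used in the code above):

```julia
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])
x = [1.0 + 1.0im, 2.0, -1.0]
y = [0.0 + 1.0im, 1.0, 2.0]

# the mapreduce over zip(x, D.diag, y) is an elementwise weighted sum
@assert dot(x, D, y) ≈ sum(conj(x[i]) * D.diag[i] * y[i] for i in eachindex(x))
```
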
45 changes: 45 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
@@ -874,6 +874,51 @@ function dot(x::AbstractArray, y::AbstractArray)
s
end

"""
dot(x, A, y)

Compute the generalized dot product `dot(x, A*y)` between two vectors `x` and `y`,
without storing the intermediate result of `A*y`. As for the two-argument
[`dot(_,_)`](@ref), this acts recursively. Moreover, for complex vectors, the
first vector is conjugated.

!!! compat "Julia 1.4"
Three-argument `dot` requires at least Julia 1.4.

# Examples
```jldoctest
julia> dot([1; 1], [1 2; 3 4], [2; 3])
26

julia> dot(1:5, reshape(1:25, 5, 5), 2:6)
4850

julia> ⋅(1:5, reshape(1:25, 5, 5), 2:6) == dot(1:5, reshape(1:25, 5, 5), 2:6)
true
```
"""
dot(x, A, y) = dot(x, A*y) # generic fallback for cases that are not covered by specialized methods

function dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
(axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch())
T = typeof(dot(first(x), first(A), first(y)))
s = zero(T)
i₁ = first(eachindex(x))
x₁ = first(x)
@inbounds for j in eachindex(y)
yj = y[j]
if !iszero(yj)
temp = zero(adjoint(A[i₁,j]) * x₁)
@simd for i in eachindex(x)
temp += adjoint(A[i,j]) * x[i]
end
s += dot(temp, yj)
end
end
return s
end
dot(x::AbstractVector, adjA::Adjoint, y::AbstractVector) = adjoint(dot(y, adjA.parent, x))
dot(x::AbstractVector, transA::Transpose{<:Real}, y::AbstractVector) = adjoint(dot(y, transA.parent, x))

###########################################################################################

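
Not part of the diff: the identities below (made-up data) restate what the generic fallback and the `Adjoint`/`Transpose` methods above guarantee, and match the assertions in test/generic.jl:

```julia
using LinearAlgebra

A = [1.0+2.0im 0.5; -1.0 3.0+1.0im]
x = [1.0, 2.0+1.0im]
y = [0.5-1.0im, 1.0]

@assert dot(x, A, y) ≈ dot(A' * x, y)       # dot(x, A, y) is interpreted as dot(A'x, y)
@assert dot(x, A', y) ≈ conj(dot(y, A, x))  # the Adjoint method defers to the parent matrix
@assert dot(x, A, y) ≈ x' * A * y           # agrees with the adjoint-vector product
```
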
31 changes: 31 additions & 0 deletions stdlib/LinearAlgebra/src/hessenberg.jl
@@ -284,6 +284,37 @@ function logabsdet(F::UpperHessenberg; shift::Number=false)
return (logdeterminant, P)
end

function dot(x::AbstractVector, H::UpperHessenberg, y::AbstractVector)
require_one_based_indexing(x, y)
m = size(H, 1)
(length(x) == m == length(y)) || throw(DimensionMismatch())
if iszero(m)
return dot(zero(eltype(x)), zero(eltype(H)), zero(eltype(y)))
end
x₁ = x[1]
r = dot(x₁, H[1,1], y[1])
r += dot(x[2], H[2,1], y[1])
@inbounds for j in 2:m-1
yj = y[j]
if !iszero(yj)
temp = adjoint(H[1,j]) * x₁
@simd for i in 2:j+1
temp += adjoint(H[i,j]) * x[i]
end
r += dot(temp, yj)
end
end
ym = y[m]
if !iszero(ym)
temp = adjoint(H[1,m]) * x₁
@simd for i in 2:m
temp += adjoint(H[i,m]) * x[i]
end
r += dot(temp, ym)
end
return r
end

######################################################################################
# Hessenberg factorizations Q(H+μI)Q' of A+μI:

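
Not part of the diff: a sketch exercising the `UpperHessenberg` method against the dense product (made-up data, assuming Julia ≥ 1.4 with this PR applied):

```julia
using LinearAlgebra

A = [2.0 1.0 0.5 0.0;
     1.0 3.0 1.0 2.0;
     0.0 1.0 4.0 1.0;
     0.0 0.0 2.0 5.0]          # already in upper-Hessenberg form
H = UpperHessenberg(A)
x = [1.0, -2.0, 0.5, 1.0]
y = [0.0, 1.0, 2.0, -1.0]

@assert dot(x, H, y) ≈ dot(x, Matrix(H) * y)
```
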
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
@@ -457,6 +457,31 @@ end

*(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)

function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
require_one_based_indexing(x, y)
(length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch())
data = A.data
r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y))
if A.uplo == 'U'
@inbounds for j = 1:length(y)
r += dot(x[j], real(data[j,j]), y[j])
@simd for i = 1:j-1
Aij = data[i,j]
r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i])
end
end
else # A.uplo == 'L'
@inbounds for j = 1:length(y)
r += dot(x[j], real(data[j,j]), y[j])
@simd for i = j+1:length(y)
Aij = data[i,j]
r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i])
end
end
end
return r
end

# Fallbacks to avoid generic_matvecmul!/generic_matmatmul!
## Symmetric{<:Number} and Hermitian{<:Real} are invariant to transpose; peel off the t
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, B::AbstractVector) = transA.parent * B
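
Not part of the diff: a sketch for the `Hermitian` method. Only the stored triangle of `A.data` is read, so wrapping an upper-triangular buffer works (made-up data):

```julia
using LinearAlgebra

M = [2.0  1.0+1.0im  0.5im;
     0.0  3.0        2.0-1.0im;
     0.0  0.0        4.0]
A = Hermitian(M, :U)            # only the upper triangle and the (real) diagonal are referenced
x = [1.0+0.0im, -1.0, 2.0im]
y = [0.5, 1.0+1.0im, 0.0]

@assert dot(x, A, y) ≈ dot(x, Matrix(A) * y)
```
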
84 changes: 84 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
@@ -545,6 +545,90 @@ end
rmul!(A::Union{UpperTriangular,LowerTriangular}, c::Number) = mul!(A, A, c)
lmul!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = mul!(A, c, A)

function dot(x::AbstractVector, A::UpperTriangular, y::AbstractVector)
require_one_based_indexing(x, y)
m = size(A, 1)
(length(x) == m == length(y)) || throw(DimensionMismatch())
if iszero(m)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
x₁ = x[1]
r = dot(x₁, A[1,1], y[1])
@inbounds for j in 2:m
yj = y[j]
if !iszero(yj)
temp = adjoint(A[1,j]) * x₁
@simd for i in 2:j
temp += adjoint(A[i,j]) * x[i]
end
r += dot(temp, yj)
end
end
return r
end
function dot(x::AbstractVector, A::UnitUpperTriangular, y::AbstractVector)
require_one_based_indexing(x, y)
m = size(A, 1)
(length(x) == m == length(y)) || throw(DimensionMismatch())
if iszero(m)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
x₁ = first(x)
r = dot(x₁, y[1])
@inbounds for j in 2:m
yj = y[j]
if !iszero(yj)
temp = adjoint(A[1,j]) * x₁
@simd for i in 2:j-1
temp += adjoint(A[i,j]) * x[i]
end
r += dot(temp, yj)
r += dot(x[j], yj)
end
end
return r
end
function dot(x::AbstractVector, A::LowerTriangular, y::AbstractVector)
require_one_based_indexing(x, y)
m = size(A, 1)
(length(x) == m == length(y)) || throw(DimensionMismatch())
if iszero(m)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
r = zero(typeof(dot(first(x), first(A), first(y))))
@inbounds for j in 1:m
yj = y[j]
if !iszero(yj)
temp = adjoint(A[j,j]) * x[j]
@simd for i in j+1:m
temp += adjoint(A[i,j]) * x[i]
end
r += dot(temp, yj)
end
end
return r
end
function dot(x::AbstractVector, A::UnitLowerTriangular, y::AbstractVector)
require_one_based_indexing(x, y)
m = size(A, 1)
(length(x) == m == length(y)) || throw(DimensionMismatch())
if iszero(m)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
r = zero(typeof(dot(first(x), first(y))))
@inbounds for j in 1:m
yj = y[j]
if !iszero(yj)
temp = x[j]
@simd for i in j+1:m
temp += adjoint(A[i,j]) * x[i]
end
r += dot(temp, yj)
end
end
return r
end

fillstored!(A::LowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), 0); A)
fillstored!(A::UnitLowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), -1); A)
fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1); A)
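
Not part of the diff: a sketch covering all four triangular wrappers handled above (made-up data); the unit variants treat the stored diagonal as implicit ones:

```julia
using LinearAlgebra

M = [1.0 2.0 3.0;
     4.0 5.0 6.0;
     7.0 8.0 9.0]
x = [1.0, -1.0, 2.0]
y = [0.5, 2.0, -1.0]

for wrap in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
    A = wrap(M)
    @assert dot(x, A, y) ≈ dot(x, Matrix(A) * y)
end
```
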
40 changes: 40 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
@@ -202,6 +202,27 @@ end
return C
end

function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(S, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
end
dv, ev = S.dv, S.ev
x₀ = x[1]
x₊ = x[2]
sub = transpose(ev[1])
r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
@inbounds for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
sup, sub = transpose(sub), transpose(ev[j])
r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
end
r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
return r
end

(\)(T::SymTridiagonal, B::StridedVecOrMat) = ldlt(T)\B

# division with optional shift for use in shifted-Hessenberg solvers (hessenberg.jl):
@@ -657,3 +678,22 @@ end

Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du)
Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev)

function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(A, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
x₀ = x[1]
x₊ = x[2]
dl, d, du = A.dl, A.d, A.du
r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
@inbounds for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
end
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
return r
end
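
Not part of the diff: a sketch for the `Tridiagonal` and `SymTridiagonal` methods above (made-up data):

```julia
using LinearAlgebra

T = Tridiagonal([1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0])  # sub-, main, and super-diagonal
S = SymTridiagonal([3.0, 4.0, 5.0], [1.0, 2.0])           # main and off-diagonal
x = [1.0, -1.0, 2.0]
y = [0.5, 2.0, -1.0]

@assert dot(x, T, y) ≈ dot(x, Matrix(T) * y)
@assert dot(x, S, y) ≈ dot(x, Matrix(S) * y)
```
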
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
@@ -400,3 +400,7 @@ Array(s::UniformScaling, dims::Dims{2}) = Matrix(s, dims)
## Diagonal construction from UniformScaling
Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m))
Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)

dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y)
dot(x::AbstractVector, a::Number, y::AbstractVector) = sum(t -> dot(t[1], a, t[2]), zip(x, y))
dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y)
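
Not part of the diff: the `UniformScaling` method defers to its scalar `λ`, and `Real`/`Complex` middle arguments factor out of the sum; other `Number` types (e.g. quaternions, as in the tests) fall back to the elementwise definition on the line above. A sketch with made-up data:

```julia
using LinearAlgebra

x = [1.0 + 2.0im, -1.0, 0.5im]
y = [2.0, 1.0 - 1.0im, 3.0]

@assert dot(x, 2I, y) ≈ 2 * dot(x, y)     # UniformScaling defers to λ = 2
@assert dot(x, 2.5, y) ≈ 2.5 * dot(x, y)  # Real/Complex scalars pull out of the sum
```
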
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
@@ -455,4 +455,17 @@ end
@test A * Tridiagonal(ones(1, 1)) == A
end

@testset "generalized dot" begin
for elty in (Float64, ComplexF64)
dv = randn(elty, 5)
ev = randn(elty, 4)
x = randn(elty, 5)
y = randn(elty, 5)
for uplo in (:U, :L)
B = Bidiagonal(dv, ev, uplo)
@test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, Matrix(B), y)
end
end
end

end # module TestBidiagonal
28 changes: 28 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
@@ -409,4 +409,32 @@ end
@test all(!isnan, lmul!(false, Any[NaN]))
end

@testset "generalized dot #32739" begin
for elty in (Int, Float32, Float64, BigFloat, Complex{Float32}, Complex{Float64}, Complex{BigFloat})
n = 10
if elty <: Int
A = rand(-n:n, n, n)
x = rand(-n:n, n)
y = rand(-n:n, n)
elseif elty <: Real
A = convert(Matrix{elty}, randn(n,n))
x = rand(elty, n)
y = rand(elty, n)
else
A = convert(Matrix{elty}, complex.(randn(n,n), randn(n,n)))
x = rand(elty, n)
y = rand(elty, n)
end
@test dot(x, A, y) ≈ dot(A'x, y) ≈ *(x', A, y) ≈ (x'A)*y
@test dot(x, A', y) ≈ dot(A*x, y) ≈ *(x', A', y) ≈ (x'A')*y
elty <: Real && @test dot(x, transpose(A), y) ≈ dot(x, transpose(A)*y) ≈ *(x', transpose(A), y) ≈ (x'*transpose(A))*y
B = reshape([A], 1, 1)
x = [x]
y = [y]
@test dot(x, B, y) ≈ dot(B'x, y)
@test dot(x, B', y) ≈ dot(B*x, y)
elty <: Real && @test dot(x, transpose(B), y) ≈ dot(x, transpose(B)*y)
end
end

end # module TestGeneric
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/test/hessenberg.jl
@@ -88,6 +88,11 @@ let n = 10
@test det(H + shift*I) ≈ det(A + shift*I)
@test logabsdet(H + shift*I) ≅ logabsdet(A + shift*I)
end

HM = Matrix(h)
@test dot(b, h, b) ≈ dot(h'b, b) ≈ dot(b, HM, b) ≈ dot(HM'b, b)
c = b .+ 1
@test dot(b, h, c) ≈ dot(h'b, c) ≈ dot(b, HM, c) ≈ dot(HM'b, c)
end
end
