diff --git a/stdlib/LinearAlgebra/src/bitarray.jl b/stdlib/LinearAlgebra/src/bitarray.jl index d1857c3c38659..2f8bce8bae083 100644 --- a/stdlib/LinearAlgebra/src/bitarray.jl +++ b/stdlib/LinearAlgebra/src/bitarray.jl @@ -2,7 +2,9 @@ function dot(x::BitVector, y::BitVector) # simplest way to mimic Array dot behavior - length(x) == length(y) || throw(DimensionMismatch()) + if axes(x) != axes(y) + throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) + end s = 0 xc = x.chunks yc = y.chunks diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 661e9e2b15617..a329cbefdf063 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -353,71 +353,27 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64), end end -@inline function _dot_length_check(x,y) - n = length(x) - if n != length(y) - throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))")) - end - n -end - -for (elty, f) in ((Float32, :dot), (Float64, :dot), - (ComplexF32, :dotc), (ComplexF64, :dotc), - (ComplexF32, :dotu), (ComplexF64, :dotu)) - @eval begin - function $f(x::DenseArray{$elty}, y::DenseArray{$elty}) - n = _dot_length_check(x,y) - $f(n, x, 1, y, 1) - end - - function $f(x::StridedVector{$elty}, y::DenseArray{$elty}) - n = _dot_length_check(x,y) - xstride = stride(x,1) - ystride = stride(y,1) - x_delta = xstride < 0 ? n : 1 - GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride) - end - - function $f(x::DenseArray{$elty}, y::StridedVector{$elty}) - n = _dot_length_check(x,y) - xstride = stride(x,1) - ystride = stride(y,1) - y_delta = ystride < 0 ? n : 1 - GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride) - end - - function $f(x::StridedVector{$elty}, y::StridedVector{$elty}) - n = _dot_length_check(x,y) - xstride = stride(x,1) - ystride = stride(y,1) - x_delta = xstride < 0 ? n : 1 - y_delta = ystride < 0 ? n : 1 - GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride) - end - end -end - function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal require_one_based_indexing(DX, DY) n = length(DX) - if n != length(DY) - throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))")) + if axes(DX) != axes(DY) + throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) end return dot(n, DX, stride(DX, 1), DY, stride(DY, 1)) end function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex require_one_based_indexing(DX, DY) n = length(DX) - if n != length(DY) - throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))")) + if axes(DX) != axes(DY) + throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) end return dotc(n, DX, stride(DX, 1), DY, stride(DY, 1)) end function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex require_one_based_indexing(DX, DY) n = length(DX) - if n != length(DY) - throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))")) + if axes(DX) != axes(DY) + throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) end return dotu(n, DX, stride(DX, 1), DY, stride(DY, 1)) end diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index cf7e474468785..cd3cbdb2a2e51 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -904,12 +904,30 @@ end dot(x::Number, y::Number) = conj(x) * y +""" + dot(x, y) + x ⋅ y + +Compute the dot product between two arrays with the same [`axes`](@ref) as if they +were vectors. For complex arrays, the elements of the first array are conjugated. +This is the classical dot product for vectors and the Hilbert-Schmidt dot +product `tr(x' * y)` for matrices. When the arrays have equal axes, calling +`dot` is semantically equivalent to `sum(dot(vx,vy) for (vx,vy) in zip(x, y))`. + +# Examples +```jldoctest +julia> dot([1; 1], [2; 3]) +5 + +julia> dot([im; im], [1; 1]) +0 - 2im +``` +""" function dot(x::AbstractArray, y::AbstractArray) - lx = length(x) - if lx != length(y) - throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y)).")) + if axes(x) != axes(y) + throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) end - if lx == 0 + if length(x) == 0 return dot(zero(eltype(x)), zero(eltype(y))) end s = zero(dot(first(x), first(y))) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 0cbfeaf9ed3ea..7e1aeff2bb8cf 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -6,9 +6,6 @@ matprod(x, y) = x*y + x*y # dot products -dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasReal} = BLAS.dot(x, y) -dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasComplex} = BLAS.dotc(x, y) - function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer} if length(rx) != length(ry) throw(DimensionMismatch("length of rx, $(length(rx)), does not equal length of ry, $(length(ry))")) diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index df29c171b2060..413739e6779c5 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -60,6 +60,10 @@ Random.seed!(100) x2 = convert(Vector{elty}, randn(n)) @test BLAS.dot(x1,x2) ≈ sum(x1.*x2) @test_throws DimensionMismatch BLAS.dot(x1,rand(elty, n + 1)) + y1 = convert(Matrix{elty}, randn(4,4)) + y2 = convert(Matrix{elty}, randn(2,8)) + @test_throws DimensionMismatch BLAS.dot(y1, y2) + @test sum(y1[i] * y2[i] for i in 1:16) ≈ BLAS.dot(vec(y1), vec(y2)) else z1 = convert(Vector{elty}, complex.(randn(n),randn(n))) z2 = convert(Vector{elty}, complex.(randn(n),randn(n))) @@ -67,6 +71,12 @@ Random.seed!(100) @test BLAS.dotu(z1,z2) ≈ sum(z1.*z2) @test_throws DimensionMismatch BLAS.dotc(z1,rand(elty, n + 1)) @test_throws DimensionMismatch BLAS.dotu(z1,rand(elty, n + 1)) + y1 = convert(Matrix{elty}, complex.(randn(4,4),randn(4,4))) + y2 = convert(Matrix{elty}, complex.(randn(2,8),randn(2,8))) + @test_throws DimensionMismatch BLAS.dotc(y1, y2) + @test_throws DimensionMismatch BLAS.dotu(y1, y2) + @test sum(conj(y1[i]) * y2[i] for i in 1:16) ≈ BLAS.dotc(vec(y1), vec(y2)) + @test sum(y1[i] * y2[i] for i in 1:16) ≈ BLAS.dotu(vec(y1), vec(y2)) end end @testset "iamax" begin diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 1017134f2f6d4..51c6483404fda 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -433,6 +433,10 @@ end @test dot(X, Y) == convert(elty, 35.0) Z = convert(Vector{Matrix{elty}},[reshape(1:4, 2, 2), fill(1, 2, 2)]) @test dot(Z, Z) == convert(elty, 34.0) + Y2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5]) + @test_throws DimensionMismatch dot(X, Y2) + @test_throws DimensionMismatch dot(vec(X), Y2) + @test dot(X, Y) == dot(vec(X), vec(Y2)) end dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y) @@ -454,6 +458,21 @@ dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y) end end end + for elty in (Float32, Float64, ComplexF32, ComplexF64) + XX = convert(Matrix{elty},[1.0 2.0; 3.0 4.0]) + YY = convert(Matrix{elty},[1.5 2.5; 3.5 4.5]) + YY2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5]) + for X in (copy(XX), view(XX, 1:2, 1:2)), Y in (copy(YY), view(YY, 1:2, 1:2)), Y2 in (copy(YY2), view(YY2, 1:1, 1:4)) + @test dot1(X, Y) == convert(elty, 35.0) + @test dot2(X, Y) == convert(elty, 35.0) + @test dot1(X, Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes + @test_throws DimensionMismatch dot2(X, Y2) + @test dot1(vec(X), Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes + @test_throws DimensionMismatch dot2(vec(X), Y2) + @test dot1(X, Y) == dot1(vec(X), vec(Y2)) + @test dot2(X, Y) == dot2(vec(X), vec(Y2)) + end + end end @testset "Issue 11978" begin diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index 357d468b42e3e..b2a4ab4a15672 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -300,7 +300,9 @@ ilog2(n::Integer) = sizeof(n)<<3 - leading_zeros(n) # Frobenius dot/inner product: trace(A'B) function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2} m, n = size(A) - size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions")) + if size(B) != (m,n) + throw(DimensionMismatch("The first array has size $(size(A)) which does not match the size of the second, $(size(B)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) + end r = dot(zero(T1), zero(T2)) @inbounds for j = 1:n ia = getcolptr(A)[j]; ia_nxt = getcolptr(A)[j+1] diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index 55ad738a7eb77..23899084565f3 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -1487,7 +1487,9 @@ end function dot(x::AbstractVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number} require_one_based_indexing(x, y) n = length(x) - length(y) == n || throw(DimensionMismatch()) + if axes(x) != axes(y) + throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) + end nzind = nonzeroinds(y) nzval = nonzeros(y) s = dot(zero(Tx), zero(Ty)) @@ -1500,7 +1502,9 @@ end function dot(x::SparseVectorUnion{Tx}, y::AbstractVector{Ty}) where {Tx<:Number,Ty<:Number} require_one_based_indexing(x, y) n = length(y) - length(x) == n || throw(DimensionMismatch()) + if axes(x) != axes(y) + throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) + end nzind = nonzeroinds(x) nzval = nonzeros(x) s = dot(zero(Tx), zero(Ty)) @@ -1534,7 +1538,9 @@ end function dot(x::SparseVectorUnion{<:Number}, y::SparseVectorUnion{<:Number}) x === y && return sum(abs2, x) n = length(x) - length(y) == n || throw(DimensionMismatch()) + if axes(x) != axes(y) + throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`.")) + end xnzind = nonzeroinds(x) ynzind = nonzeroinds(y)