Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check lengths in covector-vector products #36679

Merged
merged 12 commits into from
Jul 22, 2020
29 changes: 19 additions & 10 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,23 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...)

## multiplication *

function _dot_nonrecursive(u, v)
lu = length(u)
if lu != length(v)
throw(DimensionMismatch("first array has length $(lu) which does not match the length of the second, $(length(v))."))
end
if lu == 0
Comment on lines +270 to +275
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re OffsetArrays support, I think it's enough to check axes like this?

Suggested change
function _dot_nonrecursive(u, v)
lu = length(u)
if lu != length(v)
throw(DimensionMismatch("first array has length $(lu) which does not match the length of the second, $(length(v))."))
end
if lu == 0
function _dot_nonrecursive(u::AbstractVecOrMat, v::AbstractVecOrMat)
if !(axes(u, 1) == axes(v, 2) && axes(u, 2) == axes(v, 1))
throw(DimensionMismatch("dimensions of the first array $(axes(u)) and the second array $(axes(v)) are incompatible for a dot product."))
end
if isempty(u)

I think this is a kind of error that would have thrown if we use here promote_shape(u', v) (with a better error message). Since something like eachindex(a, b) already throws an error with something like eachindex(1:3, OffsetArray(1:3, -2:0)) simply based on axes, maybe it makes sense to check axes here?

Copy link
Contributor Author

@MasonProtter MasonProtter Jul 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we could just do

@inbounds sum(u[i]*v[i] for i in eachindex(u, v))

However, I'm still not actually convinced it's even desirable to throw an error here. I'm also not convinced we shouldn't throw an error, but I will point out that that in the current state of affairs transpose(u::OffsetVector) * v::Vector currently has the same behaviour as the implementation in this PR (i.e. no error).

I feel like it'd be good at least for 1.5 since it's so near, to just follow the lead of transpose(u) * v rather than being over eager about throwing errors.

julia> using OffsetArrays

julia> let u = OffsetArray([1,2,3], -1:1), v = [1,2,3]
           transpose(u) * v
       end
14

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm rather inclined towards throwing an error, as in the matrix-vector case. If the user really wants to contract an o::OffsetArray with a v::Vector, it is not complicated to do OffsetArrays.no_offset_view(o)' * v. Perhaps no_offset_view should be exported.

zero(eltype(u)) * zero(eltype(v))
else
sum(uu*vv for (uu, vv) in zip(u, v))
end
Comment on lines +275 to +279
Copy link
Member

@tkf tkf Jul 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are manually computing the zero length case anyway, sum(...; init = zero(eltype(u)) * zero(eltype(v))) might be easier on the compiler and (potentially) improve the performance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

Copy link
Contributor Author

@MasonProtter MasonProtter Jul 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so this is actually problematic because it makes u' * v not work for any u or v for which zero(eltype(_)) isn't defined (i.e. things like arrays of arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to undo the suggested change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, this reads basically exactly (up to the * vs dot) as:

function dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
end
if lx == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
s = zero(dot(first(x), first(y)))
for (Ix, Iy) in zip(eachindex(x), eachindex(y))
@inbounds s += dot(x[Ix], y[Iy])
end
s
end

which I think is good.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zero(eltype(_)) isn't defined (i.e. things like arrays of arrays.

Actually I feel like I did a similar suggestion before here that was ended up useless exactly due to this... I should have learned from a mistake.

end

# Adjoint/Transpose-vector * vector
*(u::AdjointAbsVec{T}, v::AbstractVector{T}) where {T<:Number} = dot(u.parent, v)
*(u::AdjointAbsVec{<:Number}, v::AbstractVector{<:Number}) = dot(u.parent, v)
*(u::TransposeAbsVec{T}, v::AbstractVector{T}) where {T<:Real} = dot(u.parent, v)
Comment on lines +283 to 284
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding that change in. Any reason not to do the equivalent thing in 284 as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I merged them, and now it looks very clean, actually. 😄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's causing method ambiguities that I'm not in the mood now of resolving. I've reverted that. I think the *(::TransVector,::AbstractVector) is handled somewhere in matmul.jl.

*(u::AdjOrTransAbsVec, v::AbstractVector) = sum(uu*vv for (uu, vv) in zip(u, v))
*(u::AdjOrTransAbsVec, v::AbstractVector) = _dot_nonrecursive(u, v)


# vector * Adjoint/Transpose-vector
*(u::AbstractVector, v::AdjOrTransAbsVec) = broadcast(*, u, v)
Expand All @@ -281,14 +294,10 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...)

# AdjOrTransAbsVec{<:Any,<:AdjOrTransAbsVec} is a lazy conj vectors
# We need to expand the combinations to avoid ambiguities
(*)(u::TransposeAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) =
sum(uu*vv for (uu, vv) in zip(u, v))
(*)(u::AdjointAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) =
sum(uu*vv for (uu, vv) in zip(u, v))
(*)(u::TransposeAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) =
sum(uu*vv for (uu, vv) in zip(u, v))
(*)(u::AdjointAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) =
sum(uu*vv for (uu, vv) in zip(u, v))
(*)(u::TransposeAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = _dot_nonrecursive(u, v)
(*)(u::AdjointAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = _dot_nonrecursive(u, v)
(*)(u::TransposeAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = _dot_nonrecursive(u, v)
(*)(u::AdjointAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = _dot_nonrecursive(u, v)

## pseudoinversion
pinv(v::AdjointAbsVec, tol::Real = 0) = pinv(v.parent, tol).parent
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,4 +550,11 @@ end
@test conj(transpose(hermitian)) === hermitian
end

@testset "empty and mismatched lengths" begin
# issue 36678
@test_throws DimensionMismatch [1, 2]' * [1,2,3]
@test Int[]' * Int[] == 0
@test transpose(Int[]) * Int[] == 0
end

end # module TestAdjointTranspose