Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ function triu!(M::Bidiagonal, k::Integer=0)
return M
end

function diag(M::Bidiagonal{T}, n::Integer=0) where T
function diag(M::Bidiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 :)

if n == 0
return M.dv
elseif n == 1
return M.uplo == 'U' ? M.ev : zeros(T, size(M,1)-1)
elseif n == -1
return M.uplo == 'L' ? M.ev : zeros(T, size(M,1)-1)
return copy!(similar(M.dv, length(M.dv)), M.dv)
Copy link
Member

@Sacha0 Sacha0 Oct 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the length(M.dv) be omitted? (Edit: Guarding against the length specification modifying the return type?)

Copy link
Member

@Sacha0 Sacha0 Oct 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moreover, why not simply copy(... rather than copy!(similar(M.dv, length(M.dv)), M.dv)? Does the return type otherwise depend upon k in some cases (e.g. for ranges)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guarding against the length specification modifying the return type?

Yes, that was the idea, to guarantee that we take the same path no matter which diagonal was requested.

Moreover, why not simply copy(... rather than copy!(similar(M.dv, length(M.dv)), M.dv)? Does the return type otherwise depend upon k in some cases (e.g. for ranges)?

Same comment applies; guarantee the same code path. Also copy(r::AbstractRange) = r so that will not yield a Vector{Int} as similar does.

julia> copy(1:4)
1:4

I don't think the copy!(similar(...), ...) uses are that terrible -- it is pretty clear what is going on IMO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good then! I might consider adding comments to make the intention clear for future readers; to someone not acquainted with the subtleties involved, the intention may be far from clear :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, Ill add that

elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L')
return copy!(similar(M.ev, length(M.ev)), M.ev)
elseif -size(M,1) <= n <= size(M,1)
return zeros(T, size(M,1)-abs(n))
return fill!(similar(M.dv, size(M,1)-abs(n)), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down
13 changes: 12 additions & 1 deletion base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,18 @@ transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
adjoint(D::Diagonal{<:Number}) = conj(D)
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))

diag(D::Diagonal) = D.diag
function diag(D::Diagonal, k::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
if k == 0
return copy!(similar(D.diag, length(D.diag)), D.diag)
elseif -size(D,1) <= k <= size(D,1)
return fill!(similar(D.diag, size(D,1)-abs(k)), 0)
else
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-size(D, 1)) ",
"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
end
end
trace(D::Diagonal) = sum(D.diag)
det(D::Diagonal) = prod(D.diag)
logdet(D::Diagonal{<:Real}) = sum(log, D.diag)
Expand Down
20 changes: 12 additions & 8 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,16 @@ broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = Sym
transpose(M::SymTridiagonal) = M #Identity operation
adjoint(M::SymTridiagonal) = conj(M)

function diag(M::SymTridiagonal{T}, n::Integer=0) where T
function diag(M::SymTridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
absn = abs(n)
if absn == 0
return M.dv
return copy!(similar(M.dv, length(M.dv)), M.dv)
elseif absn==1
return M.ev
return copy!(similar(M.ev, length(M.ev)), M.ev)
elseif absn <= size(M,1)
return zeros(T,size(M,1)-absn)
return fill!(similar(M.dv, size(M,1)-absn), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down Expand Up @@ -535,14 +537,16 @@ transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl)
adjoint(M::Tridiagonal) = conj(transpose(M))

function diag(M::Tridiagonal{T}, n::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
if n == 0
return M.d
return copy!(similar(M.d, length(M.d)), M.d)
elseif n == -1
return M.dl
return copy!(similar(M.dl, length(M.dl)), M.dl)
elseif n == 1
return M.du
return copy!(similar(M.du, length(M.du)), M.du)
elseif abs(n) <= size(M,1)
return zeros(T,size(M,1)-abs(n))
return fill!(similar(M.d, size(M,1)-abs(n)), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down
14 changes: 11 additions & 3 deletions test/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,18 @@ srand(1)
end
end

@testset "Diagonals" begin
@test diag(T,2) == zeros(elty, n-2)
@testset "diag" begin
@test (@inferred diag(T))::typeof(dv) == dv
@test (@inferred diag(T, uplo == :U ? 1 : -1))::typeof(dv) == ev
@test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2)
@test_throws ArgumentError diag(T, -n - 1)
@test_throws ArgumentError diag(T, n + 1)
@test_throws ArgumentError diag(T, n + 1)
# test diag with another wrapped vector type
gdv, gev = GenericArray(dv), GenericArray(ev)
G = Bidiagonal(gdv, gev, uplo)
@test (@inferred diag(G))::typeof(gdv) == gdv
@test (@inferred diag(G, uplo == :U ? 1 : -1))::typeof(gdv) == gev
@test (@inferred diag(G,2))::typeof(gdv) == GenericArray(zeros(elty, n-2))
end

@testset "Eigensystems" begin
Expand Down
13 changes: 12 additions & 1 deletion test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ srand(1)
@test Array(imag(D)) == imag(DM)

@test parent(D) == dd
@test diag(D) == dd
@test D[1,1] == dd[1]
@test D[1,2] == 0

Expand All @@ -51,6 +50,18 @@ srand(1)
end
end

@testset "diag" begin
@test_throws ArgumentError diag(D, n+1)
@test_throws ArgumentError diag(D, -n-1)
@test (@inferred diag(D))::typeof(dd) == dd
@test (@inferred diag(D, 0))::typeof(dd) == dd
@test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1)
DG = Diagonal(GenericArray(dd))
@test (@inferred diag(DG))::typeof(GenericArray(dd)) == GenericArray(dd)
@test (@inferred diag(DG, 1))::typeof(GenericArray(dd)) == GenericArray(zeros(elty, n-1))
end


@testset "Simple unary functions" begin
for op in (-,)
@test op(D)==op(DM)
Expand Down
15 changes: 9 additions & 6 deletions test/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,17 @@ guardsrand(123) do
@test_throws ArgumentError A[2, 3] = 1 # test assignment on the superdiagonal
end
end
@testset "Diagonal extraction" begin
@test diag(A, 1) === (mat_type == Tridiagonal ? du : dl)
@test diag(A, -1) === dl
@test diag(A, 0) === d
@test diag(A) === d
@test diag(A, n - 1) == zeros(elty, 1)
@testset "diag" begin
@test (@inferred diag(A))::typeof(d) == d
@test (@inferred diag(A, 0))::typeof(d) == d
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
@test (@inferred diag(A, -1))::typeof(d) == dl
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
@test_throws ArgumentError diag(A, -n - 1)
@test_throws ArgumentError diag(A, n + 1)
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
end
@testset "Idempotent tests" begin
for func in (conj, transpose, adjoint)
Expand Down