Skip to content

Commit 8c47b5c

Browse files
committed
WIP fix diag(::StructuredMatrix[, k=0])
1 parent b3d0f9d commit 8c47b5c

File tree

6 files changed

+55
-24
lines changed

6 files changed

+55
-24
lines changed

base/linalg/bidiag.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,13 @@ end
272272

273273
function diag(M::Bidiagonal{T}, n::Integer=0) where T
274274
if n == 0
275-
return M.dv
276-
elseif n == 1
277-
return M.uplo == 'U' ? M.ev : zeros(T, size(M,1)-1)
278-
elseif n == -1
279-
return M.uplo == 'L' ? M.ev : zeros(T, size(M,1)-1)
275+
return copy(M.dv)
276+
elseif n == 1 && M.uplo == 'U'
277+
return copy(M.ev)
278+
elseif n == -1 && M.uplo == 'L'
279+
return copy(M.ev)
280280
elseif -size(M,1) <= n <= size(M,1)
281-
return zeros(T, size(M,1)-abs(n))
281+
return fill!(similar(M.dv, size(M,1)-abs(n)), 0)
282282
else
283283
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
284284
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))

base/linalg/diagonal.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,16 @@ transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
322322
adjoint(D::Diagonal{<:Number}) = conj(D)
323323
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
324324

325-
diag(D::Diagonal) = D.diag
325+
function diag(D::Diagonal, k::Integer=0)
326+
if k == 0
327+
return copy(D.diag)
328+
elseif -size(D,1) <= k <= size(D,1)
329+
return fill!(similar(D.diag, size(D,1)-abs(k)), 0)
330+
else
331+
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-size(D, 1)) ",
332+
"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
333+
end
334+
end
326335
trace(D::Diagonal) = sum(D.diag)
327336
det(D::Diagonal) = prod(D.diag)
328337
logdet(D::Diagonal{<:Real}) = sum(log, D.diag)

base/linalg/tridiag.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ adjoint(M::SymTridiagonal) = conj(M)
133133
function diag(M::SymTridiagonal{T}, n::Integer=0) where T
134134
absn = abs(n)
135135
if absn == 0
136-
return M.dv
136+
return copy(M.dv)
137137
elseif absn==1
138-
return M.ev
138+
return copy(M.ev)
139139
elseif absn <= size(M,1)
140-
return zeros(T,size(M,1)-absn)
140+
return fill!(similar(M.dv, size(M,1)-absn), 0)
141141
else
142142
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
143143
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
@@ -536,13 +536,13 @@ adjoint(M::Tridiagonal) = conj(transpose(M))
536536

537537
function diag(M::Tridiagonal{T}, n::Integer=0) where T
538538
if n == 0
539-
return M.d
539+
return copy(M.d)
540540
elseif n == -1
541-
return M.dl
541+
return copy(M.dl)
542542
elseif n == 1
543-
return M.du
543+
return copy(M.du)
544544
elseif abs(n) <= size(M,1)
545-
return zeros(T,size(M,1)-abs(n))
545+
return fill!(similar(M.d, size(M,1)-abs(n)), 0)
546546
else
547547
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
548548
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))

test/linalg/bidiag.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,18 @@ srand(1)
216216
end
217217
end
218218

219-
@testset "Diagonals" begin
220-
@test diag(T,2) == zeros(elty, n-2)
219+
@testset "diag" begin
220+
@test (@inferred diag(T))::typeof(dv) == dv
221+
@test (@inferred diag(T, uplo == :U ? 1 : -1))::typeof(dv) == ev
222+
@test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2)
221223
@test_throws ArgumentError diag(T, -n - 1)
222-
@test_throws ArgumentError diag(T, n + 1)
224+
@test_throws ArgumentError diag(T, n + 1)
225+
# test diag with another wrapped vector type
226+
gdv, gev = GenericArray(dv), GenericArray(ev)
227+
G = Bidiagonal(gdv, gev, uplo)
228+
@test (@inferred diag(G))::typeof(gdv) == gdv
229+
@test (@inferred diag(G, uplo == :U ? 1 : -1))::typeof(gdv) == gev
230+
@test (@inferred diag(G,2))::typeof(gdv) == GenericArray(zeros(elty, n-2))
223231
end
224232

225233
@testset "Eigensystems" begin

test/linalg/diagonal.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ srand(1)
3939
@test Array(imag(D)) == imag(DM)
4040

4141
@test parent(D) == dd
42-
@test diag(D) == dd
4342
@test D[1,1] == dd[1]
4443
@test D[1,2] == 0
4544

@@ -51,6 +50,18 @@ srand(1)
5150
end
5251
end
5352

53+
@testset "diag" begin
54+
@test_throws ArgumentError diag(D, n+1)
55+
@test_throws ArgumentError diag(D, -n-1)
56+
@test (@inferred diag(D))::typeof(dd) == dd
57+
@test (@inferred diag(D, 0))::typeof(dd) == dd
58+
@test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1)
59+
DG = Diagonal(GenericArray(dd))
60+
@test (@inferred diag(DG))::typeof(GenericArray(dd)) == GenericArray(dd)
61+
@test (@inferred diag(DG, 1))::typeof(GenericArray(dd)) == GenericArray(zeros(elty, n-1))
62+
end
63+
64+
5465
@testset "Simple unary functions" begin
5566
for op in (-,)
5667
@test op(D)==op(DM)

test/linalg/tridiag.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,17 @@ guardsrand(123) do
153153
@test_throws ArgumentError A[2, 3] = 1 # test assignment on the superdiagonal
154154
end
155155
end
156-
@testset "Diagonal extraction" begin
157-
@test diag(A, 1) === (mat_type == Tridiagonal ? du : dl)
158-
@test diag(A, -1) === dl
159-
@test diag(A, 0) === d
160-
@test diag(A) === d
161-
@test diag(A, n - 1) == zeros(elty, 1)
156+
@testset "diag" begin
157+
@test (@inferred diag(A))::typeof(d) == d
158+
@test (@inferred diag(A, 0))::typeof(d) == d
159+
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
160+
@test (@inferred diag(A, -1))::typeof(d) == dl
161+
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
162162
@test_throws ArgumentError diag(A, -n - 1)
163163
@test_throws ArgumentError diag(A, n + 1)
164+
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
165+
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
166+
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
164167
end
165168
@testset "Idempotent tests" begin
166169
for func in (conj, transpose, adjoint)

0 commit comments

Comments
 (0)