Skip to content

Commit

Permalink
Bind I to UniformScaling(true) rather than UniformScaling(1) for bett…
Browse files Browse the repository at this point in the history
…er promotion behavior.
  • Loading branch information
Sacha0 committed Oct 30, 2017
1 parent 12517bd commit 9f71294
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
20 changes: 10 additions & 10 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ julia> [1 2im 3; 1im 2 3] * I
0+1im 2+0im 3+0im
```
"""
const I = UniformScaling(1)
const I = UniformScaling(true)

eltype(::Type{UniformScaling{T}}) where {T} = T
ndims(J::UniformScaling) = 2
Expand Down Expand Up @@ -99,7 +99,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
($op)(UL::$t2, J::UniformScaling) = ($t2)(($op)(UL.data, J))

function ($op)(UL::$t1, J::UniformScaling)
ULnew = copy_oftype(UL.data, promote_type(eltype(UL), eltype(J)))
ULnew = copy_oftype(UL.data, Base.Broadcast._broadcast_eltype($op, UL, J))
for i = 1:size(ULnew, 1)
ULnew[i,i] = ($op)(1, J.λ)
end
Expand All @@ -110,7 +110,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
end

function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -126,7 +126,7 @@ function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
return UpperTriangular(ULnew)
end
function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -142,28 +142,28 @@ function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
return LowerTriangular(ULnew)
end

function (+)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ}
function (+)(A::AbstractMatrix, J::UniformScaling)
n = checksquare(A)
B = similar(A, promote_type(TA,TJ))
B = similar(A, Base.Broadcast._broadcast_eltype(+, A, J))
copy!(B,A)
@inbounds for i = 1:n
B[i,i] += J.λ
end
B
end

function (-)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ<:Number}
function (-)(A::AbstractMatrix, J::UniformScaling)
n = checksquare(A)
B = similar(A, promote_type(TA,TJ))
B = similar(A, Base.Broadcast._broadcast_eltype(-, A, J))
copy!(B, A)
@inbounds for i = 1:n
B[i,i] -= J.λ
end
B
end
function (-)(J::UniformScaling{TJ}, A::AbstractMatrix{TA}) where {TA,TJ<:Number}
function (-)(J::UniformScaling, A::AbstractMatrix)
n = checksquare(A)
B = convert(AbstractMatrix{promote_type(TJ,TA)}, -A)
B = convert(AbstractMatrix{Base.Broadcast._broadcast_eltype(-, J, A)}, -A)
@inbounds for j = 1:n
B[j,j] += J.λ
end
Expand Down
2 changes: 1 addition & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ test5536(a::Union{Real, AbstractArray}) = "Non-splatting"
# issue #6142
import Base: +
mutable struct A6142 <: AbstractMatrix{Float64}; end
+(x::A6142, y::UniformScaling{TJ}) where {TJ} = "UniformScaling method called"
+(x::A6142, y::UniformScaling) = "UniformScaling method called"
+(x::A6142, y::AbstractArray) = "AbstractArray method called"
@test A6142() + I == "UniformScaling method called"
+(x::A6142, y::AbstractRange) = "AbstractRange method called" #16324 ambiguity
Expand Down
13 changes: 12 additions & 1 deletion test/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
end

@testset "det and logdet" begin
@test det(I) === 1
@test det(I) === true
@test det(1.0I) === 1.0
@test det(0I) === 0
@test det(0.0I) === 0.0
Expand Down Expand Up @@ -189,3 +189,14 @@ end
@test_throws LinAlg.PosDefException chol(-λ*I)
end
end

@testset "operations involving I should preserve eltype" begin
@test isa(Int8(1) + I, Int8)
@test isa(Float16(1) + I, Float16)
@test eltype(Int8(1)I) == Int8
@test eltype(Float16(1)I) == Float16
@test eltype(fill(Int8(1), 2, 2)I) == Int8
@test eltype(fill(Float16(1), 2, 2)I) == Float16
@test eltype(fill(Int8(1), 2, 2) + I) == Int8
@test eltype(fill(Float16(1), 2, 2) + I) == Float16
end

0 comments on commit 9f71294

Please sign in to comment.