Skip to content

Commit

Permalink
Merge pull request #1871 from stevengj/linalg_scalar
Browse files Browse the repository at this point in the history
RFC: extend linalg and array operations to work on Numbers when it makes sense
  • Loading branch information
ViralBShah committed Jan 2, 2013
2 parents 15cb385 + b8f47e8 commit 0d1730c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 13 deletions.
23 changes: 15 additions & 8 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ end

## find ##

function nnz(a::StridedArray)
function nnz(a)
n = 0
for i = 1:numel(a)
n += bool(a[i]) ? 1 : 0
Expand All @@ -1101,7 +1101,7 @@ function nnz(a::StridedArray)
end

# returns the index of the first non-zero element, or 0 if all zeros
function findfirst(A::StridedArray)
function findfirst(A)
for i = 1:length(A)
if A[i] != 0
return i
Expand All @@ -1111,7 +1111,7 @@ function findfirst(A::StridedArray)
end

# returns the index of the first matching element
function findfirst(A::StridedArray, v)
function findfirst(A, v)
for i = 1:length(A)
if A[i] == v
return i
Expand All @@ -1121,7 +1121,7 @@ function findfirst(A::StridedArray, v)
end

# returns the index of the first element for which the function returns true
function findfirst(testf::Function, A::StridedArray)
function findfirst(testf::Function, A)
for i = 1:length(A)
if testf(A[i])
return i
Expand Down Expand Up @@ -1157,6 +1157,10 @@ function find(A::StridedArray)
return I
end

find(x::Number) = x == 0 ? Array(Int,0) : [1]
find(x::Bool) = x ? [1] : Array(Int,0)
find(testf::Function, x) = find(testf(x))

findn(A::StridedVector) = find(A)

function findn(A::StridedMatrix)
Expand Down Expand Up @@ -1241,7 +1245,9 @@ function nonzeros{T}(A::StridedArray{T})
return V
end

function findmax(a::StridedArray)
nonzeros(x::Number) = x == 0 ? Array(typeof(x),0) : [x]

function findmax(a)
m = typemin(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1253,7 +1259,7 @@ function findmax(a::StridedArray)
return (m, mi)
end

function findmin(a::StridedArray)
function findmin(a)
m = typemax(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1264,8 +1270,9 @@ function findmin(a::StridedArray)
end
return (m, mi)
end
indmax(a::StridedArray) = findmax(a)[2]
indmin(a::StridedArray) = findmin(a)[2]

indmax(a) = findmax(a)[2]
indmin(a) = findmin(a)[2]

## Reductions ##

Expand Down
15 changes: 15 additions & 0 deletions base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,21 @@ function norm(A::AbstractMatrix, p)
end

norm(A::AbstractMatrix) = norm(A, 2)

norm(x::Number) = abs(x)
norm(x::Number, p) = abs(x)

rank(A::AbstractMatrix, tol::Real) = sum(svdvals(A) .> tol)
function rank(A::AbstractMatrix)
m,n = size(A)
if m == 0 || n == 0; return 0; end
sv = svdvals(A)
sum(sv .> max(size(A,1),size(A,2))*eps(sv[1]))
end
rank(x::Number) = x == 0 ? 0 : 1

trace(A::AbstractMatrix) = sum(diag(A))
trace(x::Number) = x

#kron(a::AbstractVector, b::AbstractVector)
#kron{T,S}(a::AbstractMatrix{T}, b::AbstractMatrix{S})
Expand All @@ -109,6 +115,9 @@ function cond(a::AbstractMatrix, p)
end
end

cond(x::Number) = x == 0 ? Inf : 1
cond(x::Number, p) = cond(x)

function issym(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -120,6 +129,8 @@ function issym(A::AbstractMatrix)
return true
end

issym(x::Number) = true

function ishermitian(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -131,6 +142,8 @@ function ishermitian(A::AbstractMatrix)
return true
end

ishermitian(x::Number) = isreal(x)

function istriu(A::AbstractMatrix)
m, n = size(A)
for j = 1:min(n,m-1), i = j+1:m
Expand All @@ -151,6 +164,8 @@ function istril(A::AbstractMatrix)
return true
end

istriu(x::Number) = true
istril(x::Number) = true

function linreg{T<:Number}(X::StridedVecOrMat{T}, y::Vector{T})
[ones(T, size(X,1)) X] \ y
Expand Down
29 changes: 25 additions & 4 deletions base/linalg_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ isposdef{T<:BlasFloat}(A::Matrix{T}) = isposdef!(copy(A))
isposdef{T<:Number}(A::Matrix{T}, upper::Bool) = isposdef!(float64(A), upper)
isposdef{T<:Number}(A::Matrix{T}) = isposdef!(float64(A))

isposdef(x::Number) = isreal(x) && x > 0

norm{T<:BlasFloat}(x::Vector{T}) = BLAS.nrm2(length(x), x, 1)

function norm{T<:BlasFloat, TI<:Integer}(x::Vector{T}, rx::Union(Range1{TI},Range{TI}))
Expand Down Expand Up @@ -137,6 +139,8 @@ end

diagm(v) = diagm(v, 0)

diagm(x::Number) = (X = Array(typeof(x),1,1); X[1,1] = x; X)

function trace{T}(A::Matrix{T})
t = zero(T)
for i=1:min(size(A))
Expand Down Expand Up @@ -165,6 +169,12 @@ function kron{T,S}(a::Matrix{T}, b::Matrix{S})
R
end

kron(a::Number, b::Number) = a * b
kron(a::Vector, b::Number) = a * b
kron(a::Number, b::Vector) = a * b
kron(a::Matrix, b::Number) = a * b
kron(a::Number, b::Matrix) = a * b

function randsym(n)
a = randn(n,n)
for j=1:n-1, i=j+1:n
Expand Down Expand Up @@ -235,6 +245,8 @@ function rref{T}(A::Matrix{T})
return U
end

rref(x::Number) = one(x)

## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down Expand Up @@ -377,6 +389,7 @@ end
# Matrix exponential
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))
expm(x::Number) = exp(x)

## Matrix factorizations and decompositions

Expand Down Expand Up @@ -451,6 +464,7 @@ chold{T<:Number}(A::Matrix{T}) = chold(A, true)

## Matlab (and R) compatible
chol{T<:Number}(A::Matrix{T}) = factors(chold(A))
chol(x::Number) = isreal(x) && x >= 0 ? sqrt(x) : error("argument not positive-definite")

type CholeskyDensePivoted{T<:BlasFloat} <: Factorization{T}
LR::Matrix{T}
Expand Down Expand Up @@ -537,6 +551,7 @@ lud{T<:Number}(A::Matrix{T}) = lud(float64(A))

## Matlab-compatible
lu{T<:Number}(A::Matrix{T}) = factors(lud(A))
lu(x::Number) = (one(x), x)

function det{T}(lu::LUDense{T})
m, n = size(lu)
Expand All @@ -551,6 +566,8 @@ function det(A::Matrix)
return det(lud(A))
end

det(x::Number) = x

function (\){T<:BlasFloat}(lu::LUDense{T}, B::StridedVecOrMat{T})
if lu.info > 0; throw(LAPACK.SingularException(info)); end
LAPACK.getrs!('N', lu.lu, lu.ipiv, copy(B))
Expand Down Expand Up @@ -585,6 +602,7 @@ function factors{T<:BlasFloat}(qrd::QRDense{T})
end

qr{T<:Number}(x::StridedMatrix{T}) = factors(qrd(x))
qr(x::Number) = (one(x), x)

## Multiplication by Q from the QR decomposition
(*){T<:BlasFloat}(A::QRDense{T}, B::StridedVecOrMat{T}) =
Expand Down Expand Up @@ -688,8 +706,9 @@ function eig{T<:BlasFloat}(A::StridedMatrix{T}, vecs::Bool)
end

eig{T<:Integer}(x::StridedMatrix{T}, vecs::Bool) = eig(float64(x), vecs)
eig(x::StridedMatrix) = eig(x, true)
eigvals(x::StridedMatrix) = eig(x, false)
eig(x::Number, vecs::Bool) = vecs ? (x, one(x)) : x
eig(x) = eig(x, true)
eigvals(x) = eig(x, false)

# This is the svd based on the LAPACK GESVD, which is slower, but takes
# lesser memory. It should be made available through a keyword argument
Expand Down Expand Up @@ -721,8 +740,9 @@ function svd{T<:BlasFloat}(A::StridedMatrix{T},vecs::Bool,thin::Bool)
end

svd{T<:Integer}(x::StridedMatrix{T},vecs,thin) = svd(float64(x),vecs,thin)
svd(A::StridedMatrix) = svd(A,true,false)
svd(A::StridedMatrix, thin::Bool) = svd(A,true,thin)
svd(x::Number,vecs::Bool,thin::Bool) = vecs ? (x==0?one(x):x/abs(x),abs(x),one(x)) : ([],abs(x),[])
svd(A) = svd(A,true,false)
svd(A, thin::Bool) = svd(A,true,thin)
svdvals(A) = svd(A,false,true)[2]

function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T})
Expand Down Expand Up @@ -776,6 +796,7 @@ function pinv{T<:BlasFloat}(A::StridedMatrix{T})
end
pinv{T<:Integer}(A::StridedMatrix{T}) = pinv(float(A))
pinv(a::StridedVector) = pinv(reshape(a, length(a), 1))
pinv(x::Number) = one(x)/x

## Basis for null space
function null{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down
2 changes: 2 additions & 0 deletions base/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ function dot(x::Vector, y::Vector)
s
end

dot(x::Number, y::Number) = conj(x) * y

# Matrix-vector multiplication

function (*){T<:BlasFloat}(A::StridedMatrix{T},
Expand Down
1 change: 0 additions & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ ctranspose(x::Number) = conj(x)
inv(x::Number) = one(x)/x
angle(z::Real) = atan2(zero(z), z)

# TODO: should we really treat numbers as iterable?
start(a::Real) = a
next(a::Real, i) = (a, a+1)
done(a::Real, i) = (i > a)
Expand Down

0 comments on commit 0d1730c

Please sign in to comment.