Skip to content

Commit

Permalink
extend linalg and array operations to work on Numbers when it makes s…
Browse files Browse the repository at this point in the history
…ense and is unambiguous
  • Loading branch information
stevengj committed Jan 2, 2013
1 parent 464f407 commit b8f47e8
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 13 deletions.
23 changes: 15 additions & 8 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ end

## find ##

function nnz(a::StridedArray)
function nnz(a)
n = 0
for i = 1:numel(a)
n += bool(a[i]) ? 1 : 0
Expand All @@ -1101,7 +1101,7 @@ function nnz(a::StridedArray)
end

# returns the index of the first non-zero element, or 0 if all zeros
function findfirst(A::StridedArray)
function findfirst(A)
for i = 1:length(A)
if A[i] != 0
return i
Expand All @@ -1111,7 +1111,7 @@ function findfirst(A::StridedArray)
end

# returns the index of the first matching element
function findfirst(A::StridedArray, v)
function findfirst(A, v)
for i = 1:length(A)
if A[i] == v
return i
Expand All @@ -1121,7 +1121,7 @@ function findfirst(A::StridedArray, v)
end

# returns the index of the first element for which the function returns true
function findfirst(testf::Function, A::StridedArray)
function findfirst(testf::Function, A)
for i = 1:length(A)
if testf(A[i])
return i
Expand Down Expand Up @@ -1157,6 +1157,10 @@ function find(A::StridedArray)
return I
end

find(x::Number) = x == 0 ? Array(Int,0) : [1]
find(x::Bool) = x ? [1] : Array(Int,0)
find(testf::Function, x) = find(testf(x))

findn(A::StridedVector) = find(A)

function findn(A::StridedMatrix)
Expand Down Expand Up @@ -1241,7 +1245,9 @@ function nonzeros{T}(A::StridedArray{T})
return V
end

function findmax(a::StridedArray)
nonzeros(x::Number) = x == 0 ? Array(typeof(x),0) : [x]

function findmax(a)
m = typemin(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1253,7 +1259,7 @@ function findmax(a::StridedArray)
return (m, mi)
end

function findmin(a::StridedArray)
function findmin(a)
m = typemax(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1264,8 +1270,9 @@ function findmin(a::StridedArray)
end
return (m, mi)
end
indmax(a::StridedArray) = findmax(a)[2]
indmin(a::StridedArray) = findmin(a)[2]

indmax(a) = findmax(a)[2]
indmin(a) = findmin(a)[2]

## Reductions ##

Expand Down
15 changes: 15 additions & 0 deletions base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,21 @@ function norm(A::AbstractMatrix, p)
end

norm(A::AbstractMatrix) = norm(A, 2)

norm(x::Number) = abs(x)
norm(x::Number, p) = abs(x)

rank(A::AbstractMatrix, tol::Real) = sum(svdvals(A) .> tol)
function rank(A::AbstractMatrix)
m,n = size(A)
if m == 0 || n == 0; return 0; end
sv = svdvals(A)
sum(sv .> max(size(A,1),size(A,2))*eps(sv[1]))
end
rank(x::Number) = x == 0 ? 0 : 1

trace(A::AbstractMatrix) = sum(diag(A))
trace(x::Number) = x

#kron(a::AbstractVector, b::AbstractVector)
#kron{T,S}(a::AbstractMatrix{T}, b::AbstractMatrix{S})
Expand All @@ -109,6 +115,9 @@ function cond(a::AbstractMatrix, p)
end
end

cond(x::Number) = x == 0 ? Inf : 1
cond(x::Number, p) = cond(x)

function issym(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -120,6 +129,8 @@ function issym(A::AbstractMatrix)
return true
end

issym(x::Number) = true

function ishermitian(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -131,6 +142,8 @@ function ishermitian(A::AbstractMatrix)
return true
end

ishermitian(x::Number) = isreal(x)

function istriu(A::AbstractMatrix)
m, n = size(A)
for j = 1:min(n,m-1), i = j+1:m
Expand All @@ -151,6 +164,8 @@ function istril(A::AbstractMatrix)
return true
end

istriu(x::Number) = true
istril(x::Number) = true

function linreg{T<:Number}(X::StridedVecOrMat{T}, y::Vector{T})
[ones(T, size(X,1)) X] \ y
Expand Down
29 changes: 25 additions & 4 deletions base/linalg_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ isposdef{T<:BlasFloat}(A::Matrix{T}) = isposdef!(copy(A))
isposdef{T<:Number}(A::Matrix{T}, upper::Bool) = isposdef!(float64(A), upper)
isposdef{T<:Number}(A::Matrix{T}) = isposdef!(float64(A))

isposdef(x::Number) = isreal(x) && x > 0

norm{T<:BlasFloat}(x::Vector{T}) = BLAS.nrm2(length(x), x, 1)

function norm{T<:BlasFloat, TI<:Integer}(x::Vector{T}, rx::Union(Range1{TI},Range{TI}))
Expand Down Expand Up @@ -137,6 +139,8 @@ end

diagm(v) = diagm(v, 0)

diagm(x::Number) = (X = Array(typeof(x),1,1); X[1,1] = x; X)

function trace{T}(A::Matrix{T})
t = zero(T)
for i=1:min(size(A))
Expand Down Expand Up @@ -165,6 +169,12 @@ function kron{T,S}(a::Matrix{T}, b::Matrix{S})
R
end

kron(a::Number, b::Number) = a * b
kron(a::Vector, b::Number) = a * b
kron(a::Number, b::Vector) = a * b
kron(a::Matrix, b::Number) = a * b
kron(a::Number, b::Matrix) = a * b

function randsym(n)
a = randn(n,n)
for j=1:n-1, i=j+1:n
Expand Down Expand Up @@ -235,6 +245,8 @@ function rref{T}(A::Matrix{T})
return U
end

rref(x::Number) = one(x)

## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down Expand Up @@ -377,6 +389,7 @@ end
# Matrix exponential
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))
expm(x::Number) = exp(x)

## Matrix factorizations and decompositions

Expand Down Expand Up @@ -451,6 +464,7 @@ chold{T<:Number}(A::Matrix{T}) = chold(A, true)

## Matlab (and R) compatible
chol{T<:Number}(A::Matrix{T}) = factors(chold(A))
chol(x::Number) = isreal(x) && x >= 0 ? sqrt(x) : error("argument not positive-definite")

type CholeskyDensePivoted{T<:BlasFloat} <: Factorization{T}
LR::Matrix{T}
Expand Down Expand Up @@ -537,6 +551,7 @@ lud{T<:Number}(A::Matrix{T}) = lud(float64(A))

## Matlab-compatible
lu{T<:Number}(A::Matrix{T}) = factors(lud(A))
lu(x::Number) = (one(x), x)

function det{T}(lu::LUDense{T})
m, n = size(lu)
Expand All @@ -551,6 +566,8 @@ function det(A::Matrix)
return det(lud(A))
end

det(x::Number) = x

function (\){T<:BlasFloat}(lu::LUDense{T}, B::StridedVecOrMat{T})
if lu.info > 0; throw(LAPACK.SingularException(info)); end
LAPACK.getrs!('N', lu.lu, lu.ipiv, copy(B))
Expand Down Expand Up @@ -585,6 +602,7 @@ function factors{T<:BlasFloat}(qrd::QRDense{T})
end

qr{T<:Number}(x::StridedMatrix{T}) = factors(qrd(x))
qr(x::Number) = (one(x), x)

## Multiplication by Q from the QR decomposition
(*){T<:BlasFloat}(A::QRDense{T}, B::StridedVecOrMat{T}) =
Expand Down Expand Up @@ -688,8 +706,9 @@ function eig{T<:BlasFloat}(A::StridedMatrix{T}, vecs::Bool)
end

eig{T<:Integer}(x::StridedMatrix{T}, vecs::Bool) = eig(float64(x), vecs)
eig(x::StridedMatrix) = eig(x, true)
eigvals(x::StridedMatrix) = eig(x, false)
eig(x::Number, vecs::Bool) = vecs ? (x, one(x)) : x
eig(x) = eig(x, true)
eigvals(x) = eig(x, false)

# This is the svd based on the LAPACK GESVD, which is slower, but takes
# lesser memory. It should be made available through a keyword argument
Expand Down Expand Up @@ -721,8 +740,9 @@ function svd{T<:BlasFloat}(A::StridedMatrix{T},vecs::Bool,thin::Bool)
end

svd{T<:Integer}(x::StridedMatrix{T},vecs,thin) = svd(float64(x),vecs,thin)
svd(A::StridedMatrix) = svd(A,true,false)
svd(A::StridedMatrix, thin::Bool) = svd(A,true,thin)
svd(x::Number,vecs::Bool,thin::Bool) = vecs ? (x==0?one(x):x/abs(x),abs(x),one(x)) : ([],abs(x),[])
svd(A) = svd(A,true,false)
svd(A, thin::Bool) = svd(A,true,thin)
svdvals(A) = svd(A,false,true)[2]

function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T})
Expand Down Expand Up @@ -776,6 +796,7 @@ function pinv{T<:BlasFloat}(A::StridedMatrix{T})
end
pinv{T<:Integer}(A::StridedMatrix{T}) = pinv(float(A))
pinv(a::StridedVector) = pinv(reshape(a, length(a), 1))
pinv(x::Number) = one(x)/x

## Basis for null space
function null{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down
2 changes: 2 additions & 0 deletions base/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ function dot(x::Vector, y::Vector)
s
end

dot(x::Number, y::Number) = conj(x) * y

# Matrix-vector multiplication

function (*){T<:BlasFloat}(A::StridedMatrix{T},
Expand Down
1 change: 0 additions & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ ctranspose(x::Number) = conj(x)
inv(x::Number) = one(x)/x
angle(z::Real) = atan2(zero(z), z)

# TODO: should we really treat numbers as iterable?
start(a::Real) = a
next(a::Real, i) = (a, a+1)
done(a::Real, i) = (i > a)
Expand Down

0 comments on commit b8f47e8

Please sign in to comment.