From b8f47e82c0a5c0071b57980eb74a4d9f227c76cd Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 2 Jan 2013 10:25:53 -0500 Subject: [PATCH] extend linalg and array operations to work on Numbers when it makes sense and is unambiguous --- base/array.jl | 23 +++++++++++++++-------- base/linalg.jl | 15 +++++++++++++++ base/linalg_dense.jl | 29 +++++++++++++++++++++++++---- base/matmul.jl | 2 ++ base/number.jl | 1 - 5 files changed, 57 insertions(+), 13 deletions(-) diff --git a/base/array.jl b/base/array.jl index d8b95756351c2..58cb74afea6bd 100644 --- a/base/array.jl +++ b/base/array.jl @@ -1092,7 +1092,7 @@ end ## find ## -function nnz(a::StridedArray) +function nnz(a) n = 0 for i = 1:numel(a) n += bool(a[i]) ? 1 : 0 @@ -1101,7 +1101,7 @@ function nnz(a::StridedArray) end # returns the index of the first non-zero element, or 0 if all zeros -function findfirst(A::StridedArray) +function findfirst(A) for i = 1:length(A) if A[i] != 0 return i @@ -1111,7 +1111,7 @@ function findfirst(A::StridedArray) end # returns the index of the first matching element -function findfirst(A::StridedArray, v) +function findfirst(A, v) for i = 1:length(A) if A[i] == v return i @@ -1121,7 +1121,7 @@ function findfirst(A::StridedArray, v) end # returns the index of the first element for which the function returns true -function findfirst(testf::Function, A::StridedArray) +function findfirst(testf::Function, A) for i = 1:length(A) if testf(A[i]) return i @@ -1157,6 +1157,10 @@ function find(A::StridedArray) return I end +find(x::Number) = x == 0 ? Array(Int,0) : [1] +find(x::Bool) = x ? [1] : Array(Int,0) +find(testf::Function, x) = find(testf(x)) + findn(A::StridedVector) = find(A) function findn(A::StridedMatrix) @@ -1241,7 +1245,9 @@ function nonzeros{T}(A::StridedArray{T}) return V end -function findmax(a::StridedArray) +nonzeros(x::Number) = x == 0 ? Array(typeof(x),0) : [x] + +function findmax(a) m = typemin(eltype(a)) mi = 0 for i=1:length(a) @@ -1253,7 +1259,7 @@ function findmax(a::StridedArray) return (m, mi) end -function findmin(a::StridedArray) +function findmin(a) m = typemax(eltype(a)) mi = 0 for i=1:length(a) @@ -1264,8 +1270,9 @@ function findmin(a::StridedArray) end return (m, mi) end -indmax(a::StridedArray) = findmax(a)[2] -indmin(a::StridedArray) = findmin(a)[2] + +indmax(a) = findmax(a)[2] +indmin(a) = findmin(a)[2] ## Reductions ## diff --git a/base/linalg.jl b/base/linalg.jl index a643a1e75ecd8..981aed065aede 100644 --- a/base/linalg.jl +++ b/base/linalg.jl @@ -74,6 +74,10 @@ function norm(A::AbstractMatrix, p) end norm(A::AbstractMatrix) = norm(A, 2) + +norm(x::Number) = abs(x) +norm(x::Number, p) = abs(x) + rank(A::AbstractMatrix, tol::Real) = sum(svdvals(A) .> tol) function rank(A::AbstractMatrix) m,n = size(A) @@ -81,8 +85,10 @@ function rank(A::AbstractMatrix) sv = svdvals(A) sum(sv .> max(size(A,1),size(A,2))*eps(sv[1])) end +rank(x::Number) = x == 0 ? 0 : 1 trace(A::AbstractMatrix) = sum(diag(A)) +trace(x::Number) = x #kron(a::AbstractVector, b::AbstractVector) #kron{T,S}(a::AbstractMatrix{T}, b::AbstractMatrix{S}) @@ -109,6 +115,9 @@ function cond(a::AbstractMatrix, p) end end +cond(x::Number) = x == 0 ? Inf : 1 +cond(x::Number, p) = cond(x) + function issym(A::AbstractMatrix) m, n = size(A) if m != n; error("matrix must be square, got $(m)x$(n)"); end @@ -120,6 +129,8 @@ function issym(A::AbstractMatrix) return true end +issym(x::Number) = true + function ishermitian(A::AbstractMatrix) m, n = size(A) if m != n; error("matrix must be square, got $(m)x$(n)"); end @@ -131,6 +142,8 @@ function ishermitian(A::AbstractMatrix) return true end +ishermitian(x::Number) = isreal(x) + function istriu(A::AbstractMatrix) m, n = size(A) for j = 1:min(n,m-1), i = j+1:m @@ -151,6 +164,8 @@ function istril(A::AbstractMatrix) return true end +istriu(x::Number) = true +istril(x::Number) = true function linreg{T<:Number}(X::StridedVecOrMat{T}, y::Vector{T}) [ones(T, size(X,1)) X] \ y diff --git a/base/linalg_dense.jl b/base/linalg_dense.jl index c3dd02f53b231..1fbe25ab87849 100644 --- a/base/linalg_dense.jl +++ b/base/linalg_dense.jl @@ -19,6 +19,8 @@ isposdef{T<:BlasFloat}(A::Matrix{T}) = isposdef!(copy(A)) isposdef{T<:Number}(A::Matrix{T}, upper::Bool) = isposdef!(float64(A), upper) isposdef{T<:Number}(A::Matrix{T}) = isposdef!(float64(A)) +isposdef(x::Number) = isreal(x) && x > 0 + norm{T<:BlasFloat}(x::Vector{T}) = BLAS.nrm2(length(x), x, 1) function norm{T<:BlasFloat, TI<:Integer}(x::Vector{T}, rx::Union(Range1{TI},Range{TI})) @@ -137,6 +139,8 @@ end diagm(v) = diagm(v, 0) +diagm(x::Number) = (X = Array(typeof(x),1,1); X[1,1] = x; X) + function trace{T}(A::Matrix{T}) t = zero(T) for i=1:min(size(A)) @@ -165,6 +169,12 @@ function kron{T,S}(a::Matrix{T}, b::Matrix{S}) R end +kron(a::Number, b::Number) = a * b +kron(a::Vector, b::Number) = a * b +kron(a::Number, b::Vector) = a * b +kron(a::Matrix, b::Number) = a * b +kron(a::Number, b::Matrix) = a * b + function randsym(n) a = randn(n,n) for j=1:n-1, i=j+1:n @@ -235,6 +245,8 @@ function rref{T}(A::Matrix{T}) return U end +rref(x::Number) = one(x) + ## Destructive matrix exponential using algorithm from Higham, 2008, ## "Functions of Matrices: Theory and Computation", SIAM function expm!{T<:BlasFloat}(A::StridedMatrix{T}) @@ -377,6 +389,7 @@ end # Matrix exponential expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A)) expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A)) +expm(x::Number) = exp(x) ## Matrix factorizations and decompositions @@ -451,6 +464,7 @@ chold{T<:Number}(A::Matrix{T}) = chold(A, true) ## Matlab (and R) compatible chol{T<:Number}(A::Matrix{T}) = factors(chold(A)) +chol(x::Number) = isreal(x) && x >= 0 ? sqrt(x) : error("argument not positive-definite") type CholeskyDensePivoted{T<:BlasFloat} <: Factorization{T} LR::Matrix{T} @@ -537,6 +551,7 @@ lud{T<:Number}(A::Matrix{T}) = lud(float64(A)) ## Matlab-compatible lu{T<:Number}(A::Matrix{T}) = factors(lud(A)) +lu(x::Number) = (one(x), x) function det{T}(lu::LUDense{T}) m, n = size(lu) @@ -551,6 +566,8 @@ function det(A::Matrix) return det(lud(A)) end +det(x::Number) = x + function (\){T<:BlasFloat}(lu::LUDense{T}, B::StridedVecOrMat{T}) if lu.info > 0; throw(LAPACK.SingularException(info)); end LAPACK.getrs!('N', lu.lu, lu.ipiv, copy(B)) @@ -585,6 +602,7 @@ function factors{T<:BlasFloat}(qrd::QRDense{T}) end qr{T<:Number}(x::StridedMatrix{T}) = factors(qrd(x)) +qr(x::Number) = (one(x), x) ## Multiplication by Q from the QR decomposition (*){T<:BlasFloat}(A::QRDense{T}, B::StridedVecOrMat{T}) = @@ -688,8 +706,9 @@ function eig{T<:BlasFloat}(A::StridedMatrix{T}, vecs::Bool) end eig{T<:Integer}(x::StridedMatrix{T}, vecs::Bool) = eig(float64(x), vecs) -eig(x::StridedMatrix) = eig(x, true) -eigvals(x::StridedMatrix) = eig(x, false) +eig(x::Number, vecs::Bool) = vecs ? (x, one(x)) : x +eig(x) = eig(x, true) +eigvals(x) = eig(x, false) # This is the svd based on the LAPACK GESVD, which is slower, but takes # lesser memory. It should be made available through a keyword argument @@ -721,8 +740,9 @@ function svd{T<:BlasFloat}(A::StridedMatrix{T},vecs::Bool,thin::Bool) end svd{T<:Integer}(x::StridedMatrix{T},vecs,thin) = svd(float64(x),vecs,thin) -svd(A::StridedMatrix) = svd(A,true,false) -svd(A::StridedMatrix, thin::Bool) = svd(A,true,thin) +svd(x::Number,vecs::Bool,thin::Bool) = vecs ? (x==0?one(x):x/abs(x),abs(x),one(x)) : ([],abs(x),[]) +svd(A) = svd(A,true,false) +svd(A, thin::Bool) = svd(A,true,thin) svdvals(A) = svd(A,false,true)[2] function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T}) @@ -776,6 +796,7 @@ function pinv{T<:BlasFloat}(A::StridedMatrix{T}) end pinv{T<:Integer}(A::StridedMatrix{T}) = pinv(float(A)) pinv(a::StridedVector) = pinv(reshape(a, length(a), 1)) +pinv(x::Number) = one(x)/x ## Basis for null space function null{T<:BlasFloat}(A::StridedMatrix{T}) diff --git a/base/matmul.jl b/base/matmul.jl index f8f5b17264df4..5d0366438ddd8 100644 --- a/base/matmul.jl +++ b/base/matmul.jl @@ -60,6 +60,8 @@ function dot(x::Vector, y::Vector) s end +dot(x::Number, y::Number) = conj(x) * y + # Matrix-vector multiplication function (*){T<:BlasFloat}(A::StridedMatrix{T}, diff --git a/base/number.jl b/base/number.jl index 220b4c101a24c..2261718152234 100644 --- a/base/number.jl +++ b/base/number.jl @@ -34,7 +34,6 @@ ctranspose(x::Number) = conj(x) inv(x::Number) = one(x)/x angle(z::Real) = atan2(zero(z), z) -# TODO: should we really treat numbers as iterable? start(a::Real) = a next(a::Real, i) = (a, a+1) done(a::Real, i) = (i > a)