Skip to content

Commit

Permalink
remove use of functors for sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Mar 30, 2016
1 parent 7cbc133 commit 988bd8c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 48 deletions.
1 change: 0 additions & 1 deletion base/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

module SparseArrays

using Base: Func, AddFun, OrFun, ConjFun, IdFun
using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

Expand Down
85 changes: 38 additions & 47 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

### Common definitions

import Base: Func, AddFun, MulFun, MaxFun, MinFun, SubFun, sort

immutable ComplexFun <: Func{2} end
(::ComplexFun)(x::Real, y::Real) = complex(x, y)

immutable DotFun <: Func{2} end
(::DotFun)(x::Number, y::Number) = conj(x) * y

typealias UnaryOp Union{Function, Func{1}}
typealias BinaryOp Union{Function, Func{2}}
import Base: sort

### The SparseVector

Expand Down Expand Up @@ -64,7 +55,7 @@ function _sparsevector!{Ti<:Integer}(I::Vector{Ti}, V::Vector, len::Integer)
SparseVector(len, I, V)
end

function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::BinaryOp)
function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::Function)
if !isempty(I)
p = sortperm(I)
permute!(I, p)
Expand Down Expand Up @@ -125,7 +116,7 @@ Duplicates are combined using the `combine` function, which defaults to
`+` if no `combine` argument is provided, unless the elements of `V` are Booleans
in which case `combine` defaults to `|`.
"""
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::BinaryOp)
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::Function)
length(I) == length(V) ||
throw(ArgumentError("index and value vectors must be the same length"))
len = 0
Expand All @@ -138,7 +129,7 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv},
_sparsevector!(collect(Ti, I), collect(Tv, V), len, combine)
end

function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::BinaryOp)
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::Function)
length(I) == length(V) ||
throw(ArgumentError("index and value vectors must be the same length"))
maxi = convert(Ti, len)
Expand All @@ -149,23 +140,23 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv},
end

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector}) =
sparsevec(I, V, AddFun())
sparsevec(I, V, +)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector},
len::Integer) =
sparsevec(I, V, len, AddFun())
sparsevec(I, V, len, +)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}}) =
sparsevec(I, V, OrFun())
sparsevec(I, V, |)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}},
len::Integer) =
sparsevec(I, V, len, OrFun())
sparsevec(I, V, len, |)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::BinaryOp) =
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::Function) =
sparsevec(I, fill(v, length(I)), combine)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::BinaryOp) =
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::Function) =
sparsevec(I, fill(v, length(I)), len, combine)


Expand Down Expand Up @@ -887,7 +878,7 @@ end
# 1: f(nz, nz) -> z/nz, f(z, nz) -> nz, f(nz, z) -> nz
# 2: f(nz, nz) -> z/nz, f(z, nz) -> z/nz, f(nz, z) -> z/nz

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractSparseVector{Tx},
y::AbstractSparseVector{Ty},
mode::Int)
Expand Down Expand Up @@ -925,7 +916,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp,
return SparseVector(n, rind, rval)
end

function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_0!(f::Function, mx::Int, my::Int,
xnzind, xnzval, ynzind, ynzval, rind, rval)
# f(nz, nz) -> nz, f(z, nz) -> z, f(nz, z) -> z
ir = 0; ix = 1; iy = 1
Expand All @@ -945,7 +936,7 @@ function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_1!{Tx,Ty}(f::Function, mx::Int, my::Int,
xnzind, xnzval::AbstractVector{Tx},
ynzind, ynzval::AbstractVector{Ty},
rind, rval)
Expand Down Expand Up @@ -983,7 +974,7 @@ function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_2!{Tx,Ty}(f::Function, mx::Int, my::Int,
xnzind, xnzval::AbstractVector{Tx},
ynzind, ynzval::AbstractVector{Ty},
rind, rval)
Expand Down Expand Up @@ -1029,7 +1020,7 @@ function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractVector{Tx},
y::AbstractSparseVector{Ty},
mode::Int)
Expand Down Expand Up @@ -1072,7 +1063,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp,
return dst
end

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractSparseVector{Tx},
y::AbstractVector{Ty},
mode::Int)
Expand Down Expand Up @@ -1118,13 +1109,13 @@ end

### Binary arithmetics: +, -, *

for (vop, fun, mode) in [(:_vadd, :AddFun, 1),
(:_vsub, :SubFun, 1),
(:_vmul, :MulFun, 0)]
for (vop, fun, mode) in [(:_vadd, :+, 1),
(:_vsub, :-, 1),
(:_vmul, :*, 0)]
@eval begin
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun), x, y, $mode)
end
end

Expand All @@ -1146,16 +1137,16 @@ end

# definition of other binary functions

for (op, fun, TF, mode) in [(:max, :MaxFun, :Real, 2),
(:min, :MinFun, :Real, 2),
(:complex, :ComplexFun, :Real, 1)]
for (op, TF, mode) in [(:max, :Real, 2),
(:min, :Real, 2),
(:complex, :Real, 1)]
@eval begin
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::AbstractSparseVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
$(op){Tx<:$(TF),Ty<:$(TF)}(x::StridedVector{Tx}, y::AbstractSparseVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::StridedVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
end
end

Expand Down Expand Up @@ -1192,8 +1183,8 @@ vecnorm(x::AbstractSparseVector, p::Real=2) = vecnorm(nonzeros(x), p)

# Transpose
# (The only sparse matrix structure in base is CSC, so a one-row sparse matrix is worse than dense)
transpose(x::SparseVector) = _ct(IdFun(), x)
ctranspose(x::SparseVector) = _ct(ConjFun(), x)
transpose(x::SparseVector) = _ct(identity, x)
ctranspose(x::SparseVector) = _ct(conj, x)
function _ct{T}(f, x::SparseVector{T})
isempty(x) && return Array(T, 1, 0)
A = zeros(T, 1, length(x))
Expand Down Expand Up @@ -1277,7 +1268,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractVect
return s
end

function _spdot(f::BinaryOp,
function _spdot(f::Function,
xj::Int, xj_last::Int, xnzind, xnzval,
yj::Int, yj_last::Int, ynzind, ynzval)
# dot product between ranges of non-zeros,
Expand Down Expand Up @@ -1308,7 +1299,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractSpar
xnzval = nonzeros(x)
ynzval = nonzeros(y)

_spdot(DotFun(),
_spdot(dot,
1, length(xnzind), xnzind, xnzval,
1, length(ynzind), ynzind, ynzval)
end
Expand Down Expand Up @@ -1459,15 +1450,15 @@ At_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVect
At_mul_B!(one(Tx), A, x, zero(Ty), y)

At_mul_B!{Tx,Ty}::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) =
_At_or_Ac_mul_B!(MulFun(), α, A, x, β, y)
_At_or_Ac_mul_B!(*, α, A, x, β, y)

Ac_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) =
Ac_mul_B!(one(Tx), A, x, zero(Ty), y)

Ac_mul_B!{Tx,Ty}::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) =
_At_or_Ac_mul_B!(DotFun(), α, A, x, β, y)
_At_or_Ac_mul_B!(dot, α, A, x, β, y)

function _At_or_Ac_mul_B!{Tx,Ty}(tfun::BinaryOp,
function _At_or_Ac_mul_B!{Tx,Ty}(tfun::Function,
α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx},
β::Number, y::StridedVector{Ty})
m, n = size(A)
Expand Down Expand Up @@ -1504,12 +1495,12 @@ function *(A::SparseMatrixCSC, x::AbstractSparseVector)
end

At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(MulFun(), A, x)
_At_or_Ac_mul_B(*, A, x)

Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(DotFun(), A, x)
_At_or_Ac_mul_B(dot, A, x)

function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::BinaryOp, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX})
function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::Function, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX})
m, n = size(A)
length(x) == m || throw(DimensionMismatch())
Tv = promote_type(TvA, TvX)
Expand Down

0 comments on commit 988bd8c

Please sign in to comment.