Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove use of functors for sparse matrices and vectors #15696

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
remove use of functors for sparse vectors
  • Loading branch information
KristofferC committed Mar 30, 2016
commit ade0dedf16fb863ab4e942066a46406b785cae23
1 change: 0 additions & 1 deletion base/sparse.jl
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@

module SparseArrays

using Base: Func, AddFun, OrFun, ConjFun, IdFun
using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

85 changes: 38 additions & 47 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
@@ -2,16 +2,7 @@

### Common definitions

import Base: Func, AddFun, MulFun, MaxFun, MinFun, SubFun, sort

immutable ComplexFun <: Func{2} end
(::ComplexFun)(x::Real, y::Real) = complex(x, y)

immutable DotFun <: Func{2} end
(::DotFun)(x::Number, y::Number) = conj(x) * y

typealias UnaryOp Union{Function, Func{1}}
typealias BinaryOp Union{Function, Func{2}}
import Base: sort

### The SparseVector

@@ -64,7 +55,7 @@ function _sparsevector!{Ti<:Integer}(I::Vector{Ti}, V::Vector, len::Integer)
SparseVector(len, I, V)
end

function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::BinaryOp)
function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::Function)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any potential risk if the caller passes a function expecting less/more than two arguments here? I wouldn't expect any, but better be sure of that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an internal function so it is nothing a user will call directly. Also, if the user passes a function of the wrong number of arguments it will just throw.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine then.

if !isempty(I)
p = sortperm(I)
permute!(I, p)
@@ -125,7 +116,7 @@ Duplicates are combined using the `combine` function, which defaults to
`+` if no `combine` argument is provided, unless the elements of `V` are Booleans
in which case `combine` defaults to `|`.
"""
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::BinaryOp)
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::Function)
length(I) == length(V) ||
throw(ArgumentError("index and value vectors must be the same length"))
len = 0
@@ -138,7 +129,7 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv},
_sparsevector!(collect(Ti, I), collect(Tv, V), len, combine)
end

function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::BinaryOp)
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::Function)
length(I) == length(V) ||
throw(ArgumentError("index and value vectors must be the same length"))
maxi = convert(Ti, len)
@@ -149,23 +140,23 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv},
end

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector}) =
sparsevec(I, V, AddFun())
sparsevec(I, V, +)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector},
len::Integer) =
sparsevec(I, V, len, AddFun())
sparsevec(I, V, len, +)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}}) =
sparsevec(I, V, OrFun())
sparsevec(I, V, |)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}},
len::Integer) =
sparsevec(I, V, len, OrFun())
sparsevec(I, V, len, |)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::BinaryOp) =
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::Function) =
sparsevec(I, fill(v, length(I)), combine)

sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::BinaryOp) =
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::Function) =
sparsevec(I, fill(v, length(I)), len, combine)


@@ -887,7 +878,7 @@ end
# 1: f(nz, nz) -> z/nz, f(z, nz) -> nz, f(nz, z) -> nz
# 2: f(nz, nz) -> z/nz, f(z, nz) -> z/nz, f(nz, z) -> z/nz

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractSparseVector{Tx},
y::AbstractSparseVector{Ty},
mode::Int)
@@ -925,7 +916,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp,
return SparseVector(n, rind, rval)
end

function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_0!(f::Function, mx::Int, my::Int,
xnzind, xnzval, ynzind, ynzval, rind, rval)
# f(nz, nz) -> nz, f(z, nz) -> z, f(nz, z) -> z
ir = 0; ix = 1; iy = 1
@@ -945,7 +936,7 @@ function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_1!{Tx,Ty}(f::Function, mx::Int, my::Int,
xnzind, xnzval::AbstractVector{Tx},
ynzind, ynzval::AbstractVector{Ty},
rind, rval)
@@ -983,7 +974,7 @@ function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
function _binarymap_mode_2!{Tx,Ty}(f::Function, mx::Int, my::Int,
xnzind, xnzval::AbstractVector{Tx},
ynzind, ynzval::AbstractVector{Ty},
rind, rval)
@@ -1029,7 +1020,7 @@ function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int,
return ir
end

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractVector{Tx},
y::AbstractSparseVector{Ty},
mode::Int)
@@ -1072,7 +1063,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp,
return dst
end

function _binarymap{Tx,Ty}(f::BinaryOp,
function _binarymap{Tx,Ty}(f::Function,
x::AbstractSparseVector{Tx},
y::AbstractVector{Ty},
mode::Int)
@@ -1118,13 +1109,13 @@ end

### Binary arithmetics: +, -, *

for (vop, fun, mode) in [(:_vadd, :AddFun, 1),
(:_vsub, :SubFun, 1),
(:_vmul, :MulFun, 0)]
for (vop, fun, mode) in [(:_vadd, :+, 1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also remove duplicate symbols here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What duplicate symbols? You mean the functions a few lines down can be merged with these?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Carry on, I thought my comment from above applied somewhere else too, and after you updated the commits I got confused.

(:_vsub, :-, 1),
(:_vmul, :*, 0)]
@eval begin
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun)(), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun), x, y, $mode)
end
end

@@ -1146,16 +1137,16 @@ end

# definition of other binary functions

for (op, fun, TF, mode) in [(:max, :MaxFun, :Real, 2),
(:min, :MinFun, :Real, 2),
(:complex, :ComplexFun, :Real, 1)]
for (op, TF, mode) in [(:max, :Real, 2),
(:min, :Real, 2),
(:complex, :Real, 1)]
@eval begin
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::AbstractSparseVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
$(op){Tx<:$(TF),Ty<:$(TF)}(x::StridedVector{Tx}, y::AbstractSparseVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::StridedVector{Ty}) =
_binarymap($(fun)(), x, y, $mode)
_binarymap($(op), x, y, $mode)
end
end

@@ -1192,8 +1183,8 @@ vecnorm(x::AbstractSparseVector, p::Real=2) = vecnorm(nonzeros(x), p)

# Transpose
# (The only sparse matrix structure in base is CSC, so a one-row sparse matrix is worse than dense)
transpose(x::SparseVector) = _ct(IdFun(), x)
ctranspose(x::SparseVector) = _ct(ConjFun(), x)
transpose(x::SparseVector) = _ct(identity, x)
ctranspose(x::SparseVector) = _ct(conj, x)
function _ct{T}(f, x::SparseVector{T})
isempty(x) && return Array(T, 1, 0)
A = zeros(T, 1, length(x))
@@ -1277,7 +1268,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractVect
return s
end

function _spdot(f::BinaryOp,
function _spdot(f::Function,
xj::Int, xj_last::Int, xnzind, xnzval,
yj::Int, yj_last::Int, ynzind, ynzval)
# dot product between ranges of non-zeros,
@@ -1308,7 +1299,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractSpar
xnzval = nonzeros(x)
ynzval = nonzeros(y)

_spdot(DotFun(),
_spdot(dot,
1, length(xnzind), xnzind, xnzval,
1, length(ynzind), ynzind, ynzval)
end
@@ -1459,15 +1450,15 @@ At_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVect
At_mul_B!(one(Tx), A, x, zero(Ty), y)

At_mul_B!{Tx,Ty}(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) =
_At_or_Ac_mul_B!(MulFun(), α, A, x, β, y)
_At_or_Ac_mul_B!(*, α, A, x, β, y)

Ac_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) =
Ac_mul_B!(one(Tx), A, x, zero(Ty), y)

Ac_mul_B!{Tx,Ty}(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) =
_At_or_Ac_mul_B!(DotFun(), α, A, x, β, y)
_At_or_Ac_mul_B!(dot, α, A, x, β, y)

function _At_or_Ac_mul_B!{Tx,Ty}(tfun::BinaryOp,
function _At_or_Ac_mul_B!{Tx,Ty}(tfun::Function,
α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx},
β::Number, y::StridedVector{Ty})
m, n = size(A)
@@ -1504,12 +1495,12 @@ function *(A::SparseMatrixCSC, x::AbstractSparseVector)
end

At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(MulFun(), A, x)
_At_or_Ac_mul_B(*, A, x)

Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(DotFun(), A, x)
_At_or_Ac_mul_B(dot, A, x)

function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::BinaryOp, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX})
function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::Function, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX})
m, n = size(A)
length(x) == m || throw(DimensionMismatch())
Tv = promote_type(TvA, TvX)