- Sponsor
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove use of functors for sparse matrices and vectors #15696
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,7 @@ | |
|
||
### Common definitions | ||
|
||
import Base: Func, AddFun, MulFun, MaxFun, MinFun, SubFun, sort | ||
|
||
immutable ComplexFun <: Func{2} end | ||
(::ComplexFun)(x::Real, y::Real) = complex(x, y) | ||
|
||
immutable DotFun <: Func{2} end | ||
(::DotFun)(x::Number, y::Number) = conj(x) * y | ||
|
||
typealias UnaryOp Union{Function, Func{1}} | ||
typealias BinaryOp Union{Function, Func{2}} | ||
import Base: sort | ||
|
||
### The SparseVector | ||
|
||
|
@@ -64,7 +55,7 @@ function _sparsevector!{Ti<:Integer}(I::Vector{Ti}, V::Vector, len::Integer) | |
SparseVector(len, I, V) | ||
end | ||
|
||
function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::BinaryOp) | ||
function _sparsevector!{Tv,Ti<:Integer}(I::Vector{Ti}, V::Vector{Tv}, len::Integer, combine::Function) | ||
if !isempty(I) | ||
p = sortperm(I) | ||
permute!(I, p) | ||
|
@@ -125,7 +116,7 @@ Duplicates are combined using the `combine` function, which defaults to | |
`+` if no `combine` argument is provided, unless the elements of `V` are Booleans | ||
in which case `combine` defaults to `|`. | ||
""" | ||
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::BinaryOp) | ||
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, combine::Function) | ||
length(I) == length(V) || | ||
throw(ArgumentError("index and value vectors must be the same length")) | ||
len = 0 | ||
|
@@ -138,7 +129,7 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, | |
_sparsevector!(collect(Ti, I), collect(Tv, V), len, combine) | ||
end | ||
|
||
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::BinaryOp) | ||
function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, len::Integer, combine::Function) | ||
length(I) == length(V) || | ||
throw(ArgumentError("index and value vectors must be the same length")) | ||
maxi = convert(Ti, len) | ||
|
@@ -149,23 +140,23 @@ function sparsevec{Tv,Ti<:Integer}(I::AbstractVector{Ti}, V::AbstractVector{Tv}, | |
end | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector}) = | ||
sparsevec(I, V, AddFun()) | ||
sparsevec(I, V, +) | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Number, AbstractVector}, | ||
len::Integer) = | ||
sparsevec(I, V, len, AddFun()) | ||
sparsevec(I, V, len, +) | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}}) = | ||
sparsevec(I, V, OrFun()) | ||
sparsevec(I, V, |) | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, V::Union{Bool, AbstractVector{Bool}}, | ||
len::Integer) = | ||
sparsevec(I, V, len, OrFun()) | ||
sparsevec(I, V, len, |) | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::BinaryOp) = | ||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, combine::Function) = | ||
sparsevec(I, fill(v, length(I)), combine) | ||
|
||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::BinaryOp) = | ||
sparsevec{Ti<:Integer}(I::AbstractVector{Ti}, v::Number, len::Integer, combine::Function) = | ||
sparsevec(I, fill(v, length(I)), len, combine) | ||
|
||
|
||
|
@@ -887,7 +878,7 @@ end | |
# 1: f(nz, nz) -> z/nz, f(z, nz) -> nz, f(nz, z) -> nz | ||
# 2: f(nz, nz) -> z/nz, f(z, nz) -> z/nz, f(nz, z) -> z/nz | ||
|
||
function _binarymap{Tx,Ty}(f::BinaryOp, | ||
function _binarymap{Tx,Ty}(f::Function, | ||
x::AbstractSparseVector{Tx}, | ||
y::AbstractSparseVector{Ty}, | ||
mode::Int) | ||
|
@@ -925,7 +916,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp, | |
return SparseVector(n, rind, rval) | ||
end | ||
|
||
function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int, | ||
function _binarymap_mode_0!(f::Function, mx::Int, my::Int, | ||
xnzind, xnzval, ynzind, ynzval, rind, rval) | ||
# f(nz, nz) -> nz, f(z, nz) -> z, f(nz, z) -> z | ||
ir = 0; ix = 1; iy = 1 | ||
|
@@ -945,7 +936,7 @@ function _binarymap_mode_0!(f::BinaryOp, mx::Int, my::Int, | |
return ir | ||
end | ||
|
||
function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int, | ||
function _binarymap_mode_1!{Tx,Ty}(f::Function, mx::Int, my::Int, | ||
xnzind, xnzval::AbstractVector{Tx}, | ||
ynzind, ynzval::AbstractVector{Ty}, | ||
rind, rval) | ||
|
@@ -983,7 +974,7 @@ function _binarymap_mode_1!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int, | |
return ir | ||
end | ||
|
||
function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int, | ||
function _binarymap_mode_2!{Tx,Ty}(f::Function, mx::Int, my::Int, | ||
xnzind, xnzval::AbstractVector{Tx}, | ||
ynzind, ynzval::AbstractVector{Ty}, | ||
rind, rval) | ||
|
@@ -1029,7 +1020,7 @@ function _binarymap_mode_2!{Tx,Ty}(f::BinaryOp, mx::Int, my::Int, | |
return ir | ||
end | ||
|
||
function _binarymap{Tx,Ty}(f::BinaryOp, | ||
function _binarymap{Tx,Ty}(f::Function, | ||
x::AbstractVector{Tx}, | ||
y::AbstractSparseVector{Ty}, | ||
mode::Int) | ||
|
@@ -1072,7 +1063,7 @@ function _binarymap{Tx,Ty}(f::BinaryOp, | |
return dst | ||
end | ||
|
||
function _binarymap{Tx,Ty}(f::BinaryOp, | ||
function _binarymap{Tx,Ty}(f::Function, | ||
x::AbstractSparseVector{Tx}, | ||
y::AbstractVector{Ty}, | ||
mode::Int) | ||
|
@@ -1118,13 +1109,13 @@ end | |
|
||
### Binary arithmetics: +, -, * | ||
|
||
for (vop, fun, mode) in [(:_vadd, :AddFun, 1), | ||
(:_vsub, :SubFun, 1), | ||
(:_vmul, :MulFun, 0)] | ||
for (vop, fun, mode) in [(:_vadd, :+, 1), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can also remove duplicate symbols here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What duplicate symbols? You mean the functions a few lines down can be merged with these? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Carry on, I thought my comment from above applied somewhere else too, and after you updated the commits I got confused. |
||
(:_vsub, :-, 1), | ||
(:_vmul, :*, 0)] | ||
@eval begin | ||
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode) | ||
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun)(), x, y, $mode) | ||
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun)(), x, y, $mode) | ||
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode) | ||
$(vop)(x::StridedVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode) | ||
$(vop)(x::AbstractSparseVector, y::StridedVector) = _binarymap($(fun), x, y, $mode) | ||
end | ||
end | ||
|
||
|
@@ -1146,16 +1137,16 @@ end | |
|
||
# definition of other binary functions | ||
|
||
for (op, fun, TF, mode) in [(:max, :MaxFun, :Real, 2), | ||
(:min, :MinFun, :Real, 2), | ||
(:complex, :ComplexFun, :Real, 1)] | ||
for (op, TF, mode) in [(:max, :Real, 2), | ||
(:min, :Real, 2), | ||
(:complex, :Real, 1)] | ||
@eval begin | ||
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::AbstractSparseVector{Ty}) = | ||
_binarymap($(fun)(), x, y, $mode) | ||
_binarymap($(op), x, y, $mode) | ||
$(op){Tx<:$(TF),Ty<:$(TF)}(x::StridedVector{Tx}, y::AbstractSparseVector{Ty}) = | ||
_binarymap($(fun)(), x, y, $mode) | ||
_binarymap($(op), x, y, $mode) | ||
$(op){Tx<:$(TF),Ty<:$(TF)}(x::AbstractSparseVector{Tx}, y::StridedVector{Ty}) = | ||
_binarymap($(fun)(), x, y, $mode) | ||
_binarymap($(op), x, y, $mode) | ||
end | ||
end | ||
|
||
|
@@ -1192,8 +1183,8 @@ vecnorm(x::AbstractSparseVector, p::Real=2) = vecnorm(nonzeros(x), p) | |
|
||
# Transpose | ||
# (The only sparse matrix structure in base is CSC, so a one-row sparse matrix is worse than dense) | ||
transpose(x::SparseVector) = _ct(IdFun(), x) | ||
ctranspose(x::SparseVector) = _ct(ConjFun(), x) | ||
transpose(x::SparseVector) = _ct(identity, x) | ||
ctranspose(x::SparseVector) = _ct(conj, x) | ||
function _ct{T}(f, x::SparseVector{T}) | ||
isempty(x) && return Array(T, 1, 0) | ||
A = zeros(T, 1, length(x)) | ||
|
@@ -1277,7 +1268,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractVect | |
return s | ||
end | ||
|
||
function _spdot(f::BinaryOp, | ||
function _spdot(f::Function, | ||
xj::Int, xj_last::Int, xnzind, xnzval, | ||
yj::Int, yj_last::Int, ynzind, ynzval) | ||
# dot product between ranges of non-zeros, | ||
|
@@ -1308,7 +1299,7 @@ function dot{Tx<:Number,Ty<:Number}(x::AbstractSparseVector{Tx}, y::AbstractSpar | |
xnzval = nonzeros(x) | ||
ynzval = nonzeros(y) | ||
|
||
_spdot(DotFun(), | ||
_spdot(dot, | ||
1, length(xnzind), xnzind, xnzval, | ||
1, length(ynzind), ynzind, ynzval) | ||
end | ||
|
@@ -1459,15 +1450,15 @@ At_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVect | |
At_mul_B!(one(Tx), A, x, zero(Ty), y) | ||
|
||
At_mul_B!{Tx,Ty}(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) = | ||
_At_or_Ac_mul_B!(MulFun(), α, A, x, β, y) | ||
_At_or_Ac_mul_B!(*, α, A, x, β, y) | ||
|
||
Ac_mul_B!{Tx,Ty}(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) = | ||
Ac_mul_B!(one(Tx), A, x, zero(Ty), y) | ||
|
||
Ac_mul_B!{Tx,Ty}(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, β::Number, y::StridedVector{Ty}) = | ||
_At_or_Ac_mul_B!(DotFun(), α, A, x, β, y) | ||
_At_or_Ac_mul_B!(dot, α, A, x, β, y) | ||
|
||
function _At_or_Ac_mul_B!{Tx,Ty}(tfun::BinaryOp, | ||
function _At_or_Ac_mul_B!{Tx,Ty}(tfun::Function, | ||
α::Number, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}, | ||
β::Number, y::StridedVector{Ty}) | ||
m, n = size(A) | ||
|
@@ -1504,12 +1495,12 @@ function *(A::SparseMatrixCSC, x::AbstractSparseVector) | |
end | ||
|
||
At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = | ||
_At_or_Ac_mul_B(MulFun(), A, x) | ||
_At_or_Ac_mul_B(*, A, x) | ||
|
||
Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = | ||
_At_or_Ac_mul_B(DotFun(), A, x) | ||
_At_or_Ac_mul_B(dot, A, x) | ||
|
||
function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::BinaryOp, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX}) | ||
function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::Function, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX}) | ||
m, n = size(A) | ||
length(x) == m || throw(DimensionMismatch()) | ||
Tv = promote_type(TvA, TvX) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any potential risk if the caller passes a function expecting less/more than two arguments here? I wouldn't expect any, but better be sure of that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an internal function so it is nothing a user will call directly. Also, if the user passes a function of the wrong number of arguments it will just throw.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine then.