Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
HyperSparseMatrices = "c7efdb1c-7caa-4c7d-9b5e-9093f9323c7c"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,7 +20,6 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
ChainRulesCore = "1"
HyperSparseMatrices = "0.2"
MacroTools = "0.5"
Preferences = "1"
SSGraphBLAS_jll = "6.2.1"
Expand Down
4 changes: 1 addition & 3 deletions docs/src/binaryops.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ Internally functions are lowered like this:
```@repl
using SuiteSparseGraphBLAS

op = BinaryOp(+)

typedop = op(Int64, Int64)
typedop = binaryop(+, Int64, Int64)

eadd(GBVector([1,2]), GBVector([3,4]), typedop)
```
Expand Down
3 changes: 1 addition & 2 deletions docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ Operators are lowered from a Julia function to a container like `BinaryOp` or `S
using SuiteSparseGraphBLAS
```
```@repl operators
b = BinaryOp(+)
b(Int32)
b = binaryop(+, Int32)

s = Semiring(max, +)
s(Float64)
Expand Down
4 changes: 2 additions & 2 deletions docs/src/udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ GraphBLAS supports users to supply functions as operators. Constructors exported

- `UnaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
- `BinaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
- `Monoid(name::String, binop::Union{AbstractBinaryOp, GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::[GrB_BinaryOp | AbstractBinaryOp])`
- `Monoid(name::String, binop::Union{GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::GrB_BinaryOp)`

`GrB_` prefixed arguments are typed operators, such as the result of `UnaryOps.COS[Float64]`.
Type arguments may be single types or vectors of types.
Expand Down
4 changes: 1 addition & 3 deletions docs/src/unaryops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ Internally functions are lowered like this:
```@repl
using SuiteSparseGraphBLAS

op = UnaryOp(sin)

typedop = op(Float64)
op = unaryop(sin, Float64)

map(typedop, GBVector([1.5, 0, pi]))
```
Expand Down
4 changes: 1 addition & 3 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ using Serialization
using StorageOrders

export ColMajor, RowMajor, storageorder #reexports from StorageOrders

using HyperSparseMatrices
include("abstracts.jl")
include("libutils.jl")

Expand Down Expand Up @@ -101,7 +99,7 @@ include("oriented.jl")
export SparseArrayCompat
export LibGraphBLAS
# export UnaryOps, BinaryOps, Monoids, Semirings #Submodules
export UnaryOp, BinaryOp, Monoid, Semiring #UDFs
export unaryop, binaryop, Monoid, semiring #UDFs
export Descriptor #Types
export gbset, gbget # global and object specific options.
# export xtype, ytype, ztype #Determine input/output types of operators
Expand Down
42 changes: 21 additions & 21 deletions src/abstractgbarray.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# AbstractGBArray functions:
function SparseArrays.nnz(A::AbsGBArrayOrTranspose)
function SparseArrays.nnz(A::GBArrayOrTranspose)
nvals = Ref{LibGraphBLAS.GrB_Index}()
@wraperror LibGraphBLAS.GrB_Matrix_nvals(nvals, gbpointer(parent(A)))
return Int64(nvals[])
end

Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
Base.eltype(::Type{GBArrayOrTranspose{T}}) where{T} = T

"""
empty!(v::GBVector)
Expand All @@ -14,32 +14,32 @@ Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
Clear all the entries from the GBArray.
Does not modify the type or dimensions.
"""
function Base.empty!(A::AbsGBArrayOrTranspose)
function Base.empty!(A::GBArrayOrTranspose)
@wraperror LibGraphBLAS.GrB_Matrix_clear(gbpointer(parent(A)))
return A
end

function Base.Matrix(A::AbstractGBMatrix)
function Base.Matrix(A::GBArrayOrTranspose)
sparsity = sparsitystatus(A)
T = copy(A) # We copy here to 1. avoid densifying A, and 2. to avoid destroying A.
return unpack!(T, Dense())
end

function Base.Vector(v::AbstractGBVector)
function Base.Vector(v::GBVectorOrTranspose)
sparsity = sparsitystatus(v)
T = copy(v) # avoid densifying v and destroying v.
return unpack!(T, Dense())
end

function SparseArrays.SparseMatrixCSC(A::AbstractGBArray)
function SparseArrays.SparseMatrixCSC(A::GBArrayOrTranspose)
sparsity = sparsitystatus(A)
T = copy(A) # avoid changing sparsity of A and destroying it.
return unpack!(T, SparseMatrixCSC)
end

function SparseArrays.SparseVector(v::AbstractGBVector)
function SparseArrays.SparseVector(v::GBVectorOrTranspose)
sparsity = sparsitystatus(v)
T = copy(A) # avoid changing sparsity of v and destroying it.
T = copy(v) # avoid changing sparsity of v and destroying it.
return unpack!(T, SparseVector)
end

Expand Down Expand Up @@ -94,7 +94,7 @@ for T ∈ valid_vec
function build(A::AbstractGBMatrix{$T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{$T};
combine = +
)
combine = BinaryOp(combine)($T)
combine = binaryop(combine, $T)
I isa Vector || (I = collect(I))
J isa Vector || (J = collect(J))
X isa Vector || (X = collect(X))
Expand Down Expand Up @@ -181,7 +181,7 @@ function build(
A::AbstractGBMatrix{T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{T};
combine = +
) where {T}
combine = BinaryOp(combine)(T)
combine = binaryop(combine, T)
I isa Vector || (I = collect(I))
J isa Vector || (J = collect(J))
X isa Vector || (X = collect(X))
Expand Down Expand Up @@ -314,7 +314,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
# Keywords
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
`size(M) == size(A)`.
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
- `desc::Union{Nothing, Descriptor} = nothing`

Expand All @@ -325,7 +325,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(A) != size(mask)`.
"""
function subassign!(
C::AbstractGBArray, A::AbstractGBArray, I, J;
C::AbstractGBArray, A::GBArrayOrTranspose, I, J;
mask = nothing, accum = nothing, desc = nothing
)
I, ni = idx(I)
Expand Down Expand Up @@ -397,7 +397,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
# Keywords
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
`size(M) == size(C)`.
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
- `desc::Union{Nothing, Descriptor} = nothing`

Expand All @@ -408,7 +408,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(C) != size(mask)`.
"""
function assign!(
C::AbstractGBMatrix, A::AbstractGBVector, I, J;
C::AbstractGBMatrix, A::GBArrayOrTranspose, I, J;
mask = nothing, accum = nothing, desc = nothing
)
I, ni = idx(I)
Expand All @@ -417,16 +417,16 @@ function assign!(
I = decrement!(I)
J = decrement!(J)
# we know A isn't adjoint/transpose on input
desc = _handledescriptor(desc)
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(A), I, ni, J, nj, desc)
desc = _handledescriptor(desc; in1=A)
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
increment!(I)
increment!(J)
return A
end

function assign!(C::AbstractGBArray, x, I, J;
function assign!(C::AbstractGBArray{T}, x, I, J;
mask = nothing, accum = nothing, desc = nothing
)
) where T
x = typeof(x) === T ? x : convert(T, x)
I, ni = idx(I)
J, nj = idx(J)
Expand Down Expand Up @@ -467,7 +467,7 @@ end
Base.eltype(::Type{AbstractGBVector{T}}) where{T} = T

function Base.deleteat!(v::AbstractGBVector, i)
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 1)
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 0)
return v
end

Expand Down Expand Up @@ -520,7 +520,7 @@ for T ∈ valid_vec
I isa Vector || (I = collect(I))
X isa Vector || (X = collect(X))
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
combine = BinaryOp(combine)($T)
combine = binaryop(combine, $T)
decrement!(I)
@wraperror LibGraphBLAS.$func(
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),
Expand Down Expand Up @@ -606,7 +606,7 @@ function build(v::AbstractGBVector{T}, I::Vector{<:Integer}, X::Vector{T}; combi
I isa Vector || (I = collect(I))
X isa Vector || (X = collect(X))
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
combine = BinaryOp(combine)(T)
combine = binaryop(combine, T)
decrement!(I)
@wraperror LibGraphBLAS.GrB_Matrix_build_UDT(
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),
Expand Down
3 changes: 0 additions & 3 deletions src/abstracts.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
abstract type AbstractGBType end
abstract type AbstractDescriptor end
abstract type AbstractOp end
abstract type AbstractUnaryOp <: AbstractOp end
abstract type AbstractBinaryOp <: AbstractOp end
abstract type AbstractSelectOp <: AbstractOp end
abstract type AbstractMonoid <: AbstractOp end
abstract type AbstractSemiring <: AbstractOp end
abstract type AbstractTypedOp{Z} end

abstract type AbstractGBArray{T, N, F} <: AbstractSparseArray{T, UInt64, N} end
Expand Down
2 changes: 1 addition & 1 deletion src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

# LinearAlgebra.norm doesn't like the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
LinearAlgebra.norm(A::GBVecOrMat, p::Real=2) = norm(nonzeros(A), p)
20 changes: 10 additions & 10 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
function frule(
(_, ΔA, ΔB, _)::Tuple,
::typeof(emul),
A::GBArray,
B::GBArray,
A::AbstractGBArray,
B::AbstractGBArray,
::typeof(*)
)
Ω = emul(A, B, *)
∂Ω = emul(unthunk(ΔA), B, *) + emul(unthunk(ΔB), A, *)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::GBArray, B::GBArray)
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, *)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray, ::typeof(*))
function timespullback(ΔΩ)
∂A = emul(unthunk(ΔΩ), B)
∂B = emul(unthunk(ΔΩ), A)
Expand All @@ -23,7 +23,7 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
return emul(A, B, *), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
Ω, fullpb = rrule(emul, A, B, *)
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, emulpb
Expand All @@ -39,19 +39,19 @@ end
function frule(
(_, ΔA, ΔB, _)::Tuple,
::typeof(eadd),
A::GBArray,
B::GBArray,
A::AbstractGBArray,
B::AbstractGBArray,
::typeof(+)
)
Ω = eadd(A, B, +)
∂Ω = eadd(unthunk(ΔA), unthunk(ΔB), +)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::GBArray, B::GBArray)
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, +)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray, ::typeof(+))
function pluspullback(ΔΩ)
return (
NoTangent(),
Expand All @@ -63,7 +63,7 @@ function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
return eadd(A, B, +), pluspullback
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray)
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
Ω, fullpb = rrule(eadd, A, B, +)
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, eaddpb
Expand Down
8 changes: 4 additions & 4 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ macro scalarapplyrule(func, derivative)
(_, _, $(esc(:ΔA)))::Tuple,
::typeof(apply),
::typeof($(func)),
$(esc(:A))::GBArray
$(esc(:A))::AbstractGBArray
)
$(esc(:Ω)) = apply($(esc(func)), $(esc(:A)))
return $(esc(:Ω)), $(esc(derivative)) .* unthunk($(esc(:ΔA)))
end
function ChainRulesCore.rrule(
::typeof(apply),
::typeof($(func)),
$(esc(:A))::GBArray
$(esc(:A))::AbstractGBArray
)
$(esc(:Ω)) = apply($(esc(func)), $(esc(:A)))
function applyback($(esc(:ΔA)))
Expand Down Expand Up @@ -75,10 +75,10 @@ function frule(
(_, _, ΔA)::Tuple,
::typeof(apply),
::typeof(identity),
A::GBArray
A::AbstractGBArray
)
return (A, ΔA)
end
function rrule(::typeof(apply), ::typeof(identity), A::GBArray)
function rrule(::typeof(apply), ::typeof(identity), A::AbstractGBArray)
return A, (ΔΩ) -> (NoTangent(), NoTangent(), ΔΩ)
end
Loading