Skip to content

Commit ba33380

Browse files
author
Will Kimmerer
authored
UDT/Monoid/Bugfix Update (#86)
* UnaryOp rework, remove need to @Unop * Removal of BinaryOp, Monoid update in progress * monoids working on the surface * passing all except random map tests * fix promotion * rm debug prints * rename * fix #85, fix #83 * fix #82 * work towards #81 * fixes #80, fixes #76 * fix #77
1 parent 7aeaf5b commit ba33380

39 files changed

+613
-431
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.7.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
HyperSparseMatrices = "c7efdb1c-7caa-4c7d-9b5e-9093f9323c7c"
98
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
109
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -21,7 +20,6 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2120

2221
[compat]
2322
ChainRulesCore = "1"
24-
HyperSparseMatrices = "0.2"
2523
MacroTools = "0.5"
2624
Preferences = "1"
2725
SSGraphBLAS_jll = "6.2.1"

docs/src/binaryops.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ Internally functions are lowered like this:
2727
```@repl
2828
using SuiteSparseGraphBLAS
2929
30-
op = BinaryOp(+)
31-
32-
typedop = op(Int64, Int64)
30+
typedop = binaryop(+, Int64, Int64)
3331
3432
eadd(GBVector([1,2]), GBVector([3,4]), typedop)
3533
```

docs/src/operators.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ Operators are lowered from a Julia function to a container like `BinaryOp` or `S
5151
using SuiteSparseGraphBLAS
5252
```
5353
```@repl operators
54-
b = BinaryOp(+)
55-
b(Int32)
54+
b = binaryop(+, Int32)
5655
5756
s = Semiring(max, +)
5857
s(Float64)

docs/src/udfs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ GraphBLAS supports users to supply functions as operators. Constructors exported
77

88
- `UnaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
99
- `BinaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
10-
- `Monoid(name::String, binop::Union{AbstractBinaryOp, GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
11-
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::[GrB_BinaryOp | AbstractBinaryOp])`
10+
- `Monoid(name::String, binop::Union{GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
11+
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::GrB_BinaryOp)`
1212

1313
`GrB_` prefixed arguments are typed operators, such as the result of `UnaryOps.COS[Float64]`.
1414
Type arguments may be single types or vectors of types.

docs/src/unaryops.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ Internally functions are lowered like this:
1919
```@repl
2020
using SuiteSparseGraphBLAS
2121
22-
op = UnaryOp(sin)
23-
24-
typedop = op(Float64)
22+
op = unaryop(sin, Float64)
2523
2624
map(typedop, GBVector([1.5, 0, pi]))
2725
```

src/SuiteSparseGraphBLAS.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ using Serialization
2525
using StorageOrders
2626

2727
export ColMajor, RowMajor, storageorder #reexports from StorageOrders
28-
29-
using HyperSparseMatrices
3028
include("abstracts.jl")
3129
include("libutils.jl")
3230

@@ -101,7 +99,7 @@ include("oriented.jl")
10199
export SparseArrayCompat
102100
export LibGraphBLAS
103101
# export UnaryOps, BinaryOps, Monoids, Semirings #Submodules
104-
export UnaryOp, BinaryOp, Monoid, Semiring #UDFs
102+
export unaryop, binaryop, Monoid, semiring #UDFs
105103
export Descriptor #Types
106104
export gbset, gbget # global and object specific options.
107105
# export xtype, ytype, ztype #Determine input/output types of operators

src/abstractgbarray.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# AbstractGBArray functions:
2-
function SparseArrays.nnz(A::AbsGBArrayOrTranspose)
2+
function SparseArrays.nnz(A::GBArrayOrTranspose)
33
nvals = Ref{LibGraphBLAS.GrB_Index}()
44
@wraperror LibGraphBLAS.GrB_Matrix_nvals(nvals, gbpointer(parent(A)))
55
return Int64(nvals[])
66
end
77

8-
Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
8+
Base.eltype(::Type{GBArrayOrTranspose{T}}) where{T} = T
99

1010
"""
1111
empty!(v::GBVector)
@@ -14,32 +14,32 @@ Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
1414
Clear all the entries from the GBArray.
1515
Does not modify the type or dimensions.
1616
"""
17-
function Base.empty!(A::AbsGBArrayOrTranspose)
17+
function Base.empty!(A::GBArrayOrTranspose)
1818
@wraperror LibGraphBLAS.GrB_Matrix_clear(gbpointer(parent(A)))
1919
return A
2020
end
2121

22-
function Base.Matrix(A::AbstractGBMatrix)
22+
function Base.Matrix(A::GBArrayOrTranspose)
2323
sparsity = sparsitystatus(A)
2424
T = copy(A) # We copy here to 1. avoid densifying A, and 2. to avoid destroying A.
2525
return unpack!(T, Dense())
2626
end
2727

28-
function Base.Vector(v::AbstractGBVector)
28+
function Base.Vector(v::GBVectorOrTranspose)
2929
sparsity = sparsitystatus(v)
3030
T = copy(v) # avoid densifying v and destroying v.
3131
return unpack!(T, Dense())
3232
end
3333

34-
function SparseArrays.SparseMatrixCSC(A::AbstractGBArray)
34+
function SparseArrays.SparseMatrixCSC(A::GBArrayOrTranspose)
3535
sparsity = sparsitystatus(A)
3636
T = copy(A) # avoid changing sparsity of A and destroying it.
3737
return unpack!(T, SparseMatrixCSC)
3838
end
3939

40-
function SparseArrays.SparseVector(v::AbstractGBVector)
40+
function SparseArrays.SparseVector(v::GBVectorOrTranspose)
4141
sparsity = sparsitystatus(v)
42-
T = copy(A) # avoid changing sparsity of v and destroying it.
42+
T = copy(v) # avoid changing sparsity of v and destroying it.
4343
return unpack!(T, SparseVector)
4444
end
4545

@@ -94,7 +94,7 @@ for T ∈ valid_vec
9494
function build(A::AbstractGBMatrix{$T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{$T};
9595
combine = +
9696
)
97-
combine = BinaryOp(combine)($T)
97+
combine = binaryop(combine, $T)
9898
I isa Vector || (I = collect(I))
9999
J isa Vector || (J = collect(J))
100100
X isa Vector || (X = collect(X))
@@ -181,7 +181,7 @@ function build(
181181
A::AbstractGBMatrix{T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{T};
182182
combine = +
183183
) where {T}
184-
combine = BinaryOp(combine)(T)
184+
combine = binaryop(combine, T)
185185
I isa Vector || (I = collect(I))
186186
J isa Vector || (J = collect(J))
187187
X isa Vector || (X = collect(X))
@@ -314,7 +314,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
314314
# Keywords
315315
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
316316
`size(M) == size(A)`.
317-
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
317+
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
318318
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
319319
- `desc::Union{Nothing, Descriptor} = nothing`
320320
@@ -325,7 +325,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
325325
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(A) != size(mask)`.
326326
"""
327327
function subassign!(
328-
C::AbstractGBArray, A::AbstractGBArray, I, J;
328+
C::AbstractGBArray, A::GBArrayOrTranspose, I, J;
329329
mask = nothing, accum = nothing, desc = nothing
330330
)
331331
I, ni = idx(I)
@@ -397,7 +397,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
397397
# Keywords
398398
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
399399
`size(M) == size(C)`.
400-
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
400+
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
401401
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
402402
- `desc::Union{Nothing, Descriptor} = nothing`
403403
@@ -408,7 +408,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
408408
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(C) != size(mask)`.
409409
"""
410410
function assign!(
411-
C::AbstractGBMatrix, A::AbstractGBVector, I, J;
411+
C::AbstractGBMatrix, A::GBArrayOrTranspose, I, J;
412412
mask = nothing, accum = nothing, desc = nothing
413413
)
414414
I, ni = idx(I)
@@ -417,16 +417,16 @@ function assign!(
417417
I = decrement!(I)
418418
J = decrement!(J)
419419
# we know A isn't adjoint/transpose on input
420-
desc = _handledescriptor(desc)
421-
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(A), I, ni, J, nj, desc)
420+
desc = _handledescriptor(desc; in1=A)
421+
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
422422
increment!(I)
423423
increment!(J)
424424
return A
425425
end
426426

427-
function assign!(C::AbstractGBArray, x, I, J;
427+
function assign!(C::AbstractGBArray{T}, x, I, J;
428428
mask = nothing, accum = nothing, desc = nothing
429-
)
429+
) where T
430430
x = typeof(x) === T ? x : convert(T, x)
431431
I, ni = idx(I)
432432
J, nj = idx(J)
@@ -467,7 +467,7 @@ end
467467
Base.eltype(::Type{AbstractGBVector{T}}) where{T} = T
468468

469469
function Base.deleteat!(v::AbstractGBVector, i)
470-
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 1)
470+
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 0)
471471
return v
472472
end
473473

@@ -520,7 +520,7 @@ for T ∈ valid_vec
520520
I isa Vector || (I = collect(I))
521521
X isa Vector || (X = collect(X))
522522
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
523-
combine = BinaryOp(combine)($T)
523+
combine = binaryop(combine, $T)
524524
decrement!(I)
525525
@wraperror LibGraphBLAS.$func(
526526
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),
@@ -606,7 +606,7 @@ function build(v::AbstractGBVector{T}, I::Vector{<:Integer}, X::Vector{T}; combi
606606
I isa Vector || (I = collect(I))
607607
X isa Vector || (X = collect(X))
608608
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
609-
combine = BinaryOp(combine)(T)
609+
combine = binaryop(combine, T)
610610
decrement!(I)
611611
@wraperror LibGraphBLAS.GrB_Matrix_build_UDT(
612612
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),

src/abstracts.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
abstract type AbstractGBType end
22
abstract type AbstractDescriptor end
33
abstract type AbstractOp end
4-
abstract type AbstractUnaryOp <: AbstractOp end
5-
abstract type AbstractBinaryOp <: AbstractOp end
64
abstract type AbstractSelectOp <: AbstractOp end
75
abstract type AbstractMonoid <: AbstractOp end
8-
abstract type AbstractSemiring <: AbstractOp end
96
abstract type AbstractTypedOp{Z} end
107

118
abstract type AbstractGBArray{T, N, F} <: AbstractSparseArray{T, UInt64, N} end

src/chainrules/chainruleutils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ using ChainRulesCore
44
const RealOrComplex = Union{Real, Complex}
55

66
# LinearAlgebra.norm doesn't like the nothings.
7-
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
7+
LinearAlgebra.norm(A::GBVecOrMat, p::Real=2) = norm(nonzeros(A), p)

src/chainrules/ewiserules.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
function frule(
33
(_, ΔA, ΔB, _)::Tuple,
44
::typeof(emul),
5-
A::GBArray,
6-
B::GBArray,
5+
A::AbstractGBArray,
6+
B::AbstractGBArray,
77
::typeof(*)
88
)
99
Ω = emul(A, B, *)
1010
∂Ω = emul(unthunk(ΔA), B, *) + emul(unthunk(ΔB), A, *)
1111
return Ω, ∂Ω
1212
end
13-
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::GBArray, B::GBArray)
13+
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
1414
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, *)
1515
end
1616

17-
function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
17+
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray, ::typeof(*))
1818
function timespullback(ΔΩ)
1919
∂A = emul(unthunk(ΔΩ), B)
2020
∂B = emul(unthunk(ΔΩ), A)
@@ -23,7 +23,7 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
2323
return emul(A, B, *), timespullback
2424
end
2525

26-
function rrule(::typeof(emul), A::GBArray, B::GBArray)
26+
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
2727
Ω, fullpb = rrule(emul, A, B, *)
2828
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3]
2929
return Ω, emulpb
@@ -39,19 +39,19 @@ end
3939
function frule(
4040
(_, ΔA, ΔB, _)::Tuple,
4141
::typeof(eadd),
42-
A::GBArray,
43-
B::GBArray,
42+
A::AbstractGBArray,
43+
B::AbstractGBArray,
4444
::typeof(+)
4545
)
4646
Ω = eadd(A, B, +)
4747
∂Ω = eadd(unthunk(ΔA), unthunk(ΔB), +)
4848
return Ω, ∂Ω
4949
end
50-
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::GBArray, B::GBArray)
50+
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
5151
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, +)
5252
end
5353

54-
function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
54+
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray, ::typeof(+))
5555
function pluspullback(ΔΩ)
5656
return (
5757
NoTangent(),
@@ -63,7 +63,7 @@ function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
6363
return eadd(A, B, +), pluspullback
6464
end
6565

66-
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
66+
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
6767
Ω, fullpb = rrule(eadd, A, B, +)
6868
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3]
6969
return Ω, eaddpb

0 commit comments

Comments
 (0)