Skip to content

Commit 09d22f5

Browse files
author
Will Kimmerer
authored
Type Refresh, fix inference issues, remove with functionality (#30)
* better typing * Fix type inference, remove with/context, test toml
1 parent f0dd5c9 commit 09d22f5

27 files changed

+317
-327
lines changed

Project.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,17 @@ version = "0.4.0"
66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
10-
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
119
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1210
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1311
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1412
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1513
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1614
SSGraphBLAS_jll = "7ed9a814-9cab-54e9-8e9e-d9e95b4d61b1"
1715
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
18-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1916

2017
[compat]
2118
CEnum = "0.4"
22-
ContextVariablesX = "0.1"
19+
ChainRulesCore = "0.10"
2320
MacroTools = "0.5"
2421
SSGraphBLAS_jll = "5.1"
2522
julia = "1.6"
26-
ChainRulesCore = "0.10"
27-
ChainRulesTestUtils = "0.7"
28-
FiniteDifferences = "0.12"

src/SuiteSparseGraphBLAS.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@ using MacroTools
77
using LinearAlgebra
88
using Random: randsubseq, default_rng, AbstractRNG, GLOBAL_RNG
99
using CEnum
10-
using ContextVariablesX
1110
include("abstracts.jl")
1211
include("libutils.jl")
1312
include("lib/LibGraphBLAS.jl")
1413
using .libgb
1514

1615

17-
16+
include("operators/libgbops.jl")
1817
include("types.jl")
1918
include("gbtypes.jl")
2019

@@ -49,10 +48,17 @@ const GrBOp = Union{
4948
libgb.GxB_SelectOp
5049
}
5150

51+
const TypedOp = Union{
52+
TypedUnaryOperator,
53+
TypedBinaryOperator,
54+
TypedMonoid,
55+
TypedSemiring
56+
}
57+
5258
const MonoidBinaryOrRig = Union{
53-
libgb.GrB_Monoid,
54-
libgb.GrB_Semiring,
55-
libgb.GrB_BinaryOp,
59+
TypedMonoid,
60+
TypedSemiring,
61+
TypedBinaryOperator,
5662
AbstractSemiring,
5763
AbstractBinaryOp,
5864
AbstractMonoid
@@ -63,13 +69,6 @@ const OperatorUnion = Union{
6369
GrBOp
6470
}
6571

66-
#Context variables for the `with` function
67-
@contextvar ctxop::OperatorUnion
68-
@contextvar ctxmask::Union{GBArray, Ptr} = C_NULL
69-
@contextvar ctxaccum::Union{BinaryUnion, Ptr} = C_NULL
70-
@contextvar ctxdesc::Descriptor
71-
include("with.jl")
72-
7372
include("scalar.jl")
7473
include("vector.jl")
7574
include("matrix.jl")

src/abstracts.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ abstract type AbstractBinaryOp <: AbstractOp end
66
abstract type AbstractSelectOp <: AbstractOp end
77
abstract type AbstractMonoid <: AbstractOp end
88
abstract type AbstractSemiring <: AbstractOp end
9+
abstract type AbstractTypedOp{Z} end

src/chainrules/mulrules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ function frule(
1414
B::GBMatOrTranspose,
1515
::typeof(Semirings.PLUS_TIMES)
1616
)
17-
Ω = mul(A, B)
18-
∂Ω = mul(ΔA, B) + mul(A, ΔB)
17+
Ω = mul(A, B, Semirings.PLUS_TIMES)
18+
∂Ω = mul(ΔA, B, Semirings.PLUS_TIMES) + mul(A, ΔB, Semirings.PLUS_TIMES)
1919
return Ω, ∂Ω
2020
end
2121
# Tests will not pass for this. For two reasons.
@@ -32,8 +32,8 @@ function rrule(
3232
::typeof(Semirings.PLUS_TIMES)
3333
)
3434
function mulpullback(ΔΩ)
35-
∂A = mul(ΔΩ, B'; mask=A)
36-
∂B = mul(A', ΔΩ; mask=B)
35+
∂A = mul(ΔΩ, B', Semirings.PLUS_TIMES; mask=A)
36+
∂B = mul(A', ΔΩ, Semirings.PLUS_TIMES; mask=B)
3737
return NoTangent(), ∂A, ∂B, NoTangent()
3838
end
3939
return mul(A, B), mulpullback

src/gbtypes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function _load_globaltypes()
7575
ptrtogbtype[FP32.p] = FP32
7676
global FP64 = GBType{Float64}("GrB_FP64")
7777
ptrtogbtype[FP64.p] = FP64
78-
global FC32 = GBType{ComplexF32}("GxB_FC64")
78+
global FC32 = GBType{ComplexF32}("GxB_FC32")
7979
ptrtogbtype[FC32.p] = FC32
8080
global FC64 = GBType{ComplexF32}("GxB_FC64")
8181
ptrtogbtype[FC64.p] = FC64

src/libutils.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,31 +65,21 @@ end
6565

6666

6767
"Load a global constant from SSGrB, optionally specify the resulting pointer type."
68-
function load_global(str, type = Cvoid)
68+
function load_global(str, type::Type{Ptr{T}} = Ptr{Nothing}) where {T}
6969
x =
7070
try
7171
dlsym(SSGraphBLAS_jll.libgraphblas_handle, str)
7272
catch e
7373
@warn "Symbol not available " * str
7474
return C_NULL
7575
end
76-
return unsafe_load(cglobal(x, Ptr{type}))
76+
return unsafe_load(cglobal(x, type))
7777
end
7878

79+
load_global(str, type) = load_global(str, Ptr{type})
80+
7981
isGxB(name) = name[1:3] == "GxB"
8082
isGrB(name) = name[1:3] == "GrB"
81-
"""
82-
_print_unsigned_as_signed()
83-
84-
The SuiteSparseGraphBLAS index, GrB_Index, is an alias for UInt64. Julia prints values of
85-
this type in hex, so this can be used to change the printing method to decimal.
86-
87-
This is not recommended for general use and will likely be removed once better printing is
88-
added to this package.
89-
"""
90-
function _print_unsigned_as_signed()
91-
eval(:(Base.show(io::IO, a::Unsigned) = print(io, Int(a))))
92-
end
9383

9484
function splitconstant(str)
9585
return String.(split(str, "_"))

src/operations/ewise.jl

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,22 @@ function emul!(
5555
w::GBVector,
5656
u::GBVector,
5757
v::GBVector,
58-
op = nothing;
58+
op = BinaryOps.TIMES;
5959
mask = nothing,
6060
accum = nothing,
6161
desc = nothing
6262
)
63-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
63+
mask, accum, desc = _handlenothings(mask, accum, desc)
6464
size(w) == size(u) == size(v) || throw(DimensionMismatch())
6565
op = getoperator(op, optype(u, v))
6666
accum = getoperator(accum, eltype(w))
67-
if op isa libgb.GrB_Semiring
67+
if op isa TypedSemiring
6868
libgb.GrB_Vector_eWiseMult_Semiring(w, mask, accum, op, u, v, desc)
6969
return w
70-
elseif op isa libgb.GrB_Monoid
70+
elseif op isa TypedMonoid
7171
libgb.GrB_Vector_eWiseMult_Monoid(w, mask, accum, op, u, v, desc)
7272
return w
73-
elseif op isa libgb.GrB_BinaryOp
73+
elseif op isa TypedBinaryOperator
7474
libgb.GrB_Vector_eWiseMult_BinaryOp(w, mask, accum, op, u, v, desc)
7575
return w
7676
else
@@ -82,12 +82,11 @@ end
8282
function emul(
8383
u::GBVector,
8484
v::GBVector,
85-
op = nothing;
85+
op = BinaryOps.TIMES;
8686
mask = nothing,
8787
accum = nothing,
8888
desc = nothing
8989
)
90-
op = _handlectx(op, ctxop, BinaryOps.TIMES)
9190
t = inferoutputtype(u, v, op)
9291
w = GBVector{t}(size(u))
9392
return emul!(w, u, v, op; mask , accum, desc)
@@ -97,23 +96,23 @@ function emul!(
9796
C::GBMatrix,
9897
A::GBMatOrTranspose,
9998
B::GBMatOrTranspose,
100-
op = nothing;
99+
op = BinaryOps.TIMES;
101100
mask = nothing,
102101
accum = nothing,
103102
desc = nothing
104103
)
105-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
104+
mask, accum, desc = _handlenothings(mask, accum, desc)
106105
size(C) == size(A) == size(B) || throw(DimensionMismatch())
107106
A, desc, B = _handletranspose(A, desc, B)
108107
op = getoperator(op, optype(A, B))
109108
accum = getoperator(accum, eltype(C))
110-
if op isa libgb.GrB_Semiring
109+
if op isa TypedSemiring
111110
libgb.GrB_Matrix_eWiseMult_Semiring(C, mask, accum, op, A, B, desc)
112111
return C
113-
elseif op isa libgb.GrB_Monoid
112+
elseif op isa TypedMonoid
114113
libgb.GrB_Matrix_eWiseMult_Monoid(C, mask, accum, op, A, B, desc)
115114
return C
116-
elseif op isa libgb.GrB_BinaryOp
115+
elseif op isa TypedBinaryOperator
117116
libgb.GrB_Matrix_eWiseMult_BinaryOp(C, mask, accum, op, A, B, desc)
118117
return C
119118
else
@@ -125,12 +124,11 @@ end
125124
function emul(
126125
A::GBMatOrTranspose,
127126
B::GBMatOrTranspose,
128-
op = nothing;
127+
op = BinaryOps.TIMES;
129128
mask = nothing,
130129
accum = nothing,
131130
desc = nothing
132131
)
133-
op = _handlectx(op, ctxop, BinaryOps.TIMES)
134132
t = inferoutputtype(A, B, op)
135133
C = GBMatrix{t}(size(A))
136134
return emul!(C, A, B, op; mask, accum, desc)
@@ -193,22 +191,22 @@ function eadd!(
193191
w::GBVector,
194192
u::GBVector,
195193
v::GBVector,
196-
op = nothing;
194+
op = BinaryOps.PLUS;
197195
mask = nothing,
198196
accum = nothing,
199197
desc = nothing
200198
)
201-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
199+
mask, accum, desc = _handlenothings(mask, accum, desc)
202200
size(w) == size(u) == size(v) || throw(DimensionMismatch())
203201
op = getoperator(op, optype(u, v))
204202
accum = getoperator(accum, eltype(w))
205-
if op isa libgb.GrB_Semiring
203+
if op isa TypedSemiring
206204
libgb.GrB_Vector_eWiseAdd_Semiring(w, mask, accum, op, u, v, desc)
207205
return w
208-
elseif op isa libgb.GrB_Monoid
206+
elseif op isa TypedMonoid
209207
libgb.GrB_Vector_eWiseAdd_Monoid(w, mask, accum, op, u, v, desc)
210208
return w
211-
elseif op isa libgb.GrB_BinaryOp
209+
elseif op isa TypedBinaryOperator
212210
libgb.GrB_Vector_eWiseAdd_BinaryOp(w, mask, accum, op, u, v, desc)
213211
return w
214212
else
@@ -220,12 +218,11 @@ end
220218
function eadd(
221219
u::GBVector,
222220
v::GBVector,
223-
op = nothing;
221+
op = BinaryOps.PLUS;
224222
mask = nothing,
225223
accum = nothing,
226224
desc = nothing
227225
)
228-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
229226
t = inferoutputtype(u, v, op)
230227
w = GBVector{t}(size(u))
231228
return eadd!(w, u, v, op; mask, accum, desc)
@@ -235,23 +232,23 @@ function eadd!(
235232
C::GBMatrix,
236233
A::GBMatOrTranspose,
237234
B::GBMatOrTranspose,
238-
op = nothing;
235+
op = BinaryOps.PLUS;
239236
mask = nothing,
240237
accum = nothing,
241238
desc = nothing
242239
)
243-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
240+
mask, accum, desc = _handlenothings(mask, accum, desc)
244241
size(C) == size(A) == size(B) || throw(DimensionMismatch())
245242
A, desc, B = _handletranspose(A, desc, B)
246243
op = getoperator(op, optype(A, B))
247244
accum = getoperator(accum, eltype(C))
248-
if op isa libgb.GrB_Semiring
245+
if op isa TypedSemiring
249246
libgb.GrB_Matrix_eWiseAdd_Semiring(C, mask, accum, op, A, B, desc)
250247
return C
251-
elseif op isa libgb.GrB_Monoid
248+
elseif op isa TypedMonoid
252249
libgb.GrB_Matrix_eWiseAdd_Monoid(C, mask, accum, op, A, B, desc)
253250
return C
254-
elseif op isa libgb.GrB_BinaryOp
251+
elseif op isa TypedBinaryOperator
255252
libgb.GrB_Matrix_eWiseAdd_BinaryOp(C, mask, accum, op, A, B, desc)
256253
return C
257254
else
@@ -263,19 +260,18 @@ end
263260
function eadd(
264261
A::GBMatOrTranspose,
265262
B::GBMatOrTranspose,
266-
op = nothing;
263+
op = BinaryOps.PLUS;
267264
mask = nothing,
268265
accum = nothing,
269266
desc = nothing
270267
)
271-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
272268
t = inferoutputtype(A, B, op)
273269
C = GBMatrix{t}(size(A))
274270
return eadd!(C, A, B, op; mask, accum, desc)
275271
end
276272

277273
function Base.:+(A::GBArray, B::GBArray)
278-
eadd(A, B, nothing)
274+
eadd(A, B, BinaryOps.PLUS)
279275
end
280276

281277
function Base.:-(A::GBArray, B::GBArray)
@@ -284,16 +280,6 @@ end
284280
#Elementwise Broadcasts
285281
#######################
286282

287-
# default argument is missing to avoid `nothing` picking up the default default :).
288-
function Base.broadcasted(::typeof(), A::GBArray, B::GBArray)
289-
eadd(A, B, missing)
290-
end
291-
292-
# default argument is missing to avoid `nothing` picking up the default default :).
293-
function Base.broadcasted(::typeof(), A::GBArray, B::GBArray)
294-
emul(A, B, missing)
295-
end
296-
297283
function Base.broadcasted(::typeof(*), A::GBArray, B::GBArray)
298284
emul(A, B, BinaryOps.TIMES)
299285
end

src/operations/kronecker.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@ function LinearAlgebra.kron!(
77
C::GBMatOrTranspose,
88
A::GBMatOrTranspose,
99
B::GBMatOrTranspose,
10-
op = nothing;
10+
op = BinaryOps.TIMES;
1111
mask = nothing,
1212
accum = nothing,
1313
desc = nothing
1414
)
15-
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
15+
mask, accum, desc = _handlenothings(mask, accum, desc)
1616
op = getoperator(op, optype(A, B))
1717
A, desc, B = _handletranspose(A, desc, B)
1818
accum = getoperator(accum, eltype(C))
19-
if op isa libgb.GrB_BinaryOp
19+
if op isa TypedBinaryOperator
2020
libgb.GxB_kron(C, mask, accum, op, A, B, desc)
21-
elseif op isa libgb.GrB_Monoid
21+
elseif op isa TypedMonoid
2222
libgb.GrB_Matrix_kronecker_Monoid(C, mask, accum, op, A, B, desc)
23-
elseif op isa libgb.GrB_Semiring
23+
elseif op isa TypedSemiring
2424
libgb.GrB_Matrix_kronecker_Semiring(C, mask, accum, op, A, B, desc)
2525
else
2626
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
@@ -47,12 +47,11 @@ Does not support `GBVector`s at this time.
4747
function LinearAlgebra.kron(
4848
A::GBMatOrTranspose,
4949
B::GBMatOrTranspose,
50-
op = nothing;
50+
op = BinaryOps.TIMES;
5151
mask = nothing,
5252
accum = nothing,
5353
desc = nothing
5454
)
55-
op = _handlectx(op, ctxop, BinaryOps.TIMES)
5655
t = inferoutputtype(A, B, op)
5756
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
5857
kron!(C, A, B, op; mask, accum, desc)

0 commit comments

Comments
 (0)