Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,17 @@ version = "0.4.0"
[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSGraphBLAS_jll = "7ed9a814-9cab-54e9-8e9e-d9e95b4d61b1"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CEnum = "0.4"
ContextVariablesX = "0.1"
ChainRulesCore = "0.10"
MacroTools = "0.5"
SSGraphBLAS_jll = "5.1"
julia = "1.6"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.7"
FiniteDifferences = "0.12"
23 changes: 11 additions & 12 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ using MacroTools
using LinearAlgebra
using Random: randsubseq, default_rng, AbstractRNG, GLOBAL_RNG
using CEnum
using ContextVariablesX
include("abstracts.jl")
include("libutils.jl")
include("lib/LibGraphBLAS.jl")
using .libgb



include("operators/libgbops.jl")
include("types.jl")
include("gbtypes.jl")

Expand Down Expand Up @@ -49,10 +48,17 @@ const GrBOp = Union{
libgb.GxB_SelectOp
}

const TypedOp = Union{
TypedUnaryOperator,
TypedBinaryOperator,
TypedMonoid,
TypedSemiring
}

const MonoidBinaryOrRig = Union{
libgb.GrB_Monoid,
libgb.GrB_Semiring,
libgb.GrB_BinaryOp,
TypedMonoid,
TypedSemiring,
TypedBinaryOperator,
AbstractSemiring,
AbstractBinaryOp,
AbstractMonoid
Expand All @@ -63,13 +69,6 @@ const OperatorUnion = Union{
GrBOp
}

#Context variables for the `with` function
@contextvar ctxop::OperatorUnion
@contextvar ctxmask::Union{GBArray, Ptr} = C_NULL
@contextvar ctxaccum::Union{BinaryUnion, Ptr} = C_NULL
@contextvar ctxdesc::Descriptor
include("with.jl")

include("scalar.jl")
include("vector.jl")
include("matrix.jl")
Expand Down
1 change: 1 addition & 0 deletions src/abstracts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ abstract type AbstractBinaryOp <: AbstractOp end
abstract type AbstractSelectOp <: AbstractOp end
abstract type AbstractMonoid <: AbstractOp end
abstract type AbstractSemiring <: AbstractOp end
abstract type AbstractTypedOp{Z} end
8 changes: 4 additions & 4 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ function frule(
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
Ω = mul(A, B)
∂Ω = mul(ΔA, B) + mul(A, ΔB)
Ω = mul(A, B, Semirings.PLUS_TIMES)
∂Ω = mul(ΔA, B, Semirings.PLUS_TIMES) + mul(A, ΔB, Semirings.PLUS_TIMES)
return Ω, ∂Ω
end
# Tests will not pass for this. For two reasons.
Expand All @@ -32,8 +32,8 @@ function rrule(
::typeof(Semirings.PLUS_TIMES)
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B'; mask=A)
∂B = mul(A', ΔΩ; mask=B)
∂A = mul(ΔΩ, B', Semirings.PLUS_TIMES; mask=A)
∂B = mul(A', ΔΩ, Semirings.PLUS_TIMES; mask=B)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return mul(A, B), mulpullback
Expand Down
2 changes: 1 addition & 1 deletion src/gbtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function _load_globaltypes()
ptrtogbtype[FP32.p] = FP32
global FP64 = GBType{Float64}("GrB_FP64")
ptrtogbtype[FP64.p] = FP64
global FC32 = GBType{ComplexF32}("GxB_FC64")
global FC32 = GBType{ComplexF32}("GxB_FC32")
ptrtogbtype[FC32.p] = FC32
global FC64 = GBType{ComplexF32}("GxB_FC64")
ptrtogbtype[FC64.p] = FC64
Expand Down
18 changes: 4 additions & 14 deletions src/libutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,31 +65,21 @@ end


"Load a global constant from SSGrB, optionally specify the resulting pointer type."
function load_global(str, type = Cvoid)
function load_global(str, type::Type{Ptr{T}} = Ptr{Nothing}) where {T}
x =
try
dlsym(SSGraphBLAS_jll.libgraphblas_handle, str)
catch e
@warn "Symbol not available " * str
return C_NULL
end
return unsafe_load(cglobal(x, Ptr{type}))
return unsafe_load(cglobal(x, type))
end

load_global(str, type) = load_global(str, Ptr{type})

isGxB(name) = name[1:3] == "GxB"
isGrB(name) = name[1:3] == "GrB"
"""
_print_unsigned_as_signed()

The SuiteSparseGraphBLAS index, GrB_Index, is an alias for UInt64. Julia prints values of
this type in hex, so this can be used to change the printing method to decimal.

This is not recommended for general use and will likely be removed once better printing is
added to this package.
"""
function _print_unsigned_as_signed()
eval(:(Base.show(io::IO, a::Unsigned) = print(io, Int(a))))
end

function splitconstant(str)
return String.(split(str, "_"))
Expand Down
64 changes: 25 additions & 39 deletions src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ function emul!(
w::GBVector,
u::GBVector,
v::GBVector,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
mask, accum, desc = _handlenothings(mask, accum, desc)
size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
if op isa libgb.GrB_Semiring
if op isa TypedSemiring
libgb.GrB_Vector_eWiseMult_Semiring(w, mask, accum, op, u, v, desc)
return w
elseif op isa libgb.GrB_Monoid
elseif op isa TypedMonoid
libgb.GrB_Vector_eWiseMult_Monoid(w, mask, accum, op, u, v, desc)
return w
elseif op isa libgb.GrB_BinaryOp
elseif op isa TypedBinaryOperator
libgb.GrB_Vector_eWiseMult_BinaryOp(w, mask, accum, op, u, v, desc)
return w
else
Expand All @@ -82,12 +82,11 @@ end
function emul(
u::GBVector,
v::GBVector,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
t = inferoutputtype(u, v, op)
w = GBVector{t}(size(u))
return emul!(w, u, v, op; mask , accum, desc)
Expand All @@ -97,23 +96,23 @@ function emul!(
C::GBMatrix,
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
mask, accum, desc = _handlenothings(mask, accum, desc)
size(C) == size(A) == size(B) || throw(DimensionMismatch())
A, desc, B = _handletranspose(A, desc, B)
op = getoperator(op, optype(A, B))
accum = getoperator(accum, eltype(C))
if op isa libgb.GrB_Semiring
if op isa TypedSemiring
libgb.GrB_Matrix_eWiseMult_Semiring(C, mask, accum, op, A, B, desc)
return C
elseif op isa libgb.GrB_Monoid
elseif op isa TypedMonoid
libgb.GrB_Matrix_eWiseMult_Monoid(C, mask, accum, op, A, B, desc)
return C
elseif op isa libgb.GrB_BinaryOp
elseif op isa TypedBinaryOperator
libgb.GrB_Matrix_eWiseMult_BinaryOp(C, mask, accum, op, A, B, desc)
return C
else
Expand All @@ -125,12 +124,11 @@ end
function emul(
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A))
return emul!(C, A, B, op; mask, accum, desc)
Expand Down Expand Up @@ -193,22 +191,22 @@ function eadd!(
w::GBVector,
u::GBVector,
v::GBVector,
op = nothing;
op = BinaryOps.PLUS;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
mask, accum, desc = _handlenothings(mask, accum, desc)
size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
if op isa libgb.GrB_Semiring
if op isa TypedSemiring
libgb.GrB_Vector_eWiseAdd_Semiring(w, mask, accum, op, u, v, desc)
return w
elseif op isa libgb.GrB_Monoid
elseif op isa TypedMonoid
libgb.GrB_Vector_eWiseAdd_Monoid(w, mask, accum, op, u, v, desc)
return w
elseif op isa libgb.GrB_BinaryOp
elseif op isa TypedBinaryOperator
libgb.GrB_Vector_eWiseAdd_BinaryOp(w, mask, accum, op, u, v, desc)
return w
else
Expand All @@ -220,12 +218,11 @@ end
function eadd(
u::GBVector,
v::GBVector,
op = nothing;
op = BinaryOps.PLUS;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
t = inferoutputtype(u, v, op)
w = GBVector{t}(size(u))
return eadd!(w, u, v, op; mask, accum, desc)
Expand All @@ -235,23 +232,23 @@ function eadd!(
C::GBMatrix,
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.PLUS;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
mask, accum, desc = _handlenothings(mask, accum, desc)
size(C) == size(A) == size(B) || throw(DimensionMismatch())
A, desc, B = _handletranspose(A, desc, B)
op = getoperator(op, optype(A, B))
accum = getoperator(accum, eltype(C))
if op isa libgb.GrB_Semiring
if op isa TypedSemiring
libgb.GrB_Matrix_eWiseAdd_Semiring(C, mask, accum, op, A, B, desc)
return C
elseif op isa libgb.GrB_Monoid
elseif op isa TypedMonoid
libgb.GrB_Matrix_eWiseAdd_Monoid(C, mask, accum, op, A, B, desc)
return C
elseif op isa libgb.GrB_BinaryOp
elseif op isa TypedBinaryOperator
libgb.GrB_Matrix_eWiseAdd_BinaryOp(C, mask, accum, op, A, B, desc)
return C
else
Expand All @@ -263,19 +260,18 @@ end
function eadd(
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.PLUS;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A))
return eadd!(C, A, B, op; mask, accum, desc)
end

function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
eadd(A, B, BinaryOps.PLUS)
end

function Base.:-(A::GBArray, B::GBArray)
Expand All @@ -284,16 +280,6 @@ end
#Elementwise Broadcasts
#######################

# default argument is missing to avoid `nothing` picking up the default default :).
function Base.broadcasted(::typeof(∪), A::GBArray, B::GBArray)
eadd(A, B, missing)
end

# default argument is missing to avoid `nothing` picking up the default default :).
function Base.broadcasted(::typeof(∩), A::GBArray, B::GBArray)
emul(A, B, missing)
end

function Base.broadcasted(::typeof(*), A::GBArray, B::GBArray)
emul(A, B, BinaryOps.TIMES)
end
Expand Down
13 changes: 6 additions & 7 deletions src/operations/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ function LinearAlgebra.kron!(
C::GBMatOrTranspose,
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
mask, accum, desc = _handlenothings(mask, accum, desc)
op = getoperator(op, optype(A, B))
A, desc, B = _handletranspose(A, desc, B)
accum = getoperator(accum, eltype(C))
if op isa libgb.GrB_BinaryOp
if op isa TypedBinaryOperator
libgb.GxB_kron(C, mask, accum, op, A, B, desc)
elseif op isa libgb.GrB_Monoid
elseif op isa TypedMonoid
libgb.GrB_Matrix_kronecker_Monoid(C, mask, accum, op, A, B, desc)
elseif op isa libgb.GrB_Semiring
elseif op isa TypedSemiring
libgb.GrB_Matrix_kronecker_Semiring(C, mask, accum, op, A, B, desc)
else
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
Expand All @@ -47,12 +47,11 @@ Does not support `GBVector`s at this time.
function LinearAlgebra.kron(
A::GBMatOrTranspose,
B::GBMatOrTranspose,
op = nothing;
op = BinaryOps.TIMES;
mask = nothing,
accum = nothing,
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
kron!(C, A, B, op; mask, accum, desc)
Expand Down
Loading