From 0448185afb36cf0664451947f9c6a582be1cdc2e Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sun, 11 Jul 2021 13:51:19 -0400 Subject: [PATCH 1/3] better (maybe) typing --- src/SuiteSparseGraphBLAS.jl | 15 +++++-- src/abstracts.jl | 1 + src/gbtypes.jl | 2 +- src/libutils.jl | 18 ++------ src/operations/ewise.jl | 24 +++++------ src/operations/kronecker.jl | 6 +-- src/operations/mul.jl | 6 +-- src/operations/operationutils.jl | 12 +++--- src/operators/binaryops.jl | 59 ++++++++++++------------- src/operators/libgbops.jl | 25 +++++++++++ src/operators/monoids.jl | 74 ++++++++++++++++---------------- src/operators/operatorutils.jl | 10 ++--- src/operators/semirings.jl | 65 +++++++++++++--------------- src/operators/unaryops.jl | 55 ++++++++++++------------ src/types.jl | 34 +++++++++++++++ test/chainrules/mulrules.jl | 16 +++---- 16 files changed, 235 insertions(+), 187 deletions(-) create mode 100644 src/operators/libgbops.jl diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index c3800558..40ed763a 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -14,7 +14,7 @@ include("lib/LibGraphBLAS.jl") using .libgb - +include("operators/libgbops.jl") include("types.jl") include("gbtypes.jl") @@ -49,10 +49,17 @@ const GrBOp = Union{ libgb.GxB_SelectOp } +const TypedOp = Union{ + TypedUnaryOperator, + TypedBinaryOperator, + TypedMonoid, + TypedSemiring +} + const MonoidBinaryOrRig = Union{ - libgb.GrB_Monoid, - libgb.GrB_Semiring, - libgb.GrB_BinaryOp, + TypedMonoid, + TypedSemiring, + TypedBinaryOperator, AbstractSemiring, AbstractBinaryOp, AbstractMonoid diff --git a/src/abstracts.jl b/src/abstracts.jl index 047dfb2d..566a1b4d 100644 --- a/src/abstracts.jl +++ b/src/abstracts.jl @@ -6,3 +6,4 @@ abstract type AbstractBinaryOp <: AbstractOp end abstract type AbstractSelectOp <: AbstractOp end abstract type AbstractMonoid <: AbstractOp end abstract type AbstractSemiring <: AbstractOp end +abstract type AbstractTypedOp{Z} end diff --git a/src/gbtypes.jl b/src/gbtypes.jl index abd4a2f8..37885935 100644 --- a/src/gbtypes.jl +++ b/src/gbtypes.jl @@ -75,7 +75,7 @@ function _load_globaltypes() ptrtogbtype[FP32.p] = FP32 global FP64 = GBType{Float64}("GrB_FP64") ptrtogbtype[FP64.p] = FP64 - global FC32 = GBType{ComplexF32}("GxB_FC64") + global FC32 = GBType{ComplexF32}("GxB_FC32") ptrtogbtype[FC32.p] = FC32 global FC64 = GBType{ComplexF32}("GxB_FC64") ptrtogbtype[FC64.p] = FC64 diff --git a/src/libutils.jl b/src/libutils.jl index 40dde32e..59425a8d 100644 --- a/src/libutils.jl +++ b/src/libutils.jl @@ -65,7 +65,7 @@ end "Load a global constant from SSGrB, optionally specify the resulting pointer type." -function load_global(str, type = Cvoid) +function load_global(str, type::Type{Ptr{T}} = Ptr{Nothing}) where {T} x = try dlsym(SSGraphBLAS_jll.libgraphblas_handle, str) @@ -73,23 +73,13 @@ function load_global(str, type = Cvoid) @warn "Symbol not available " * str return C_NULL end - return unsafe_load(cglobal(x, Ptr{type})) + return unsafe_load(cglobal(x, type)) end +load_global(str, type) = load_global(str, Ptr{type}) + isGxB(name) = name[1:3] == "GxB" isGrB(name) = name[1:3] == "GrB" -""" - _print_unsigned_as_signed() - -The SuiteSparseGraphBLAS index, GrB_Index, is an alias for UInt64. Julia prints values of -this type in hex, so this can be used to change the printing method to decimal. - -This is not recommended for general use and will likely be removed once better printing is -added to this package. -""" -function _print_unsigned_as_signed() - eval(:(Base.show(io::IO, a::Unsigned) = print(io, Int(a)))) -end function splitconstant(str) return String.(split(str, "_")) diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index b30b2807..f2c9b89d 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -64,13 +64,13 @@ function emul!( size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) - if op isa libgb.GrB_Semiring + if op isa TypedSemiring libgb.GrB_Vector_eWiseMult_Semiring(w, mask, accum, op, u, v, desc) return w - elseif op isa libgb.GrB_Monoid + elseif op isa TypedMonoid libgb.GrB_Vector_eWiseMult_Monoid(w, mask, accum, op, u, v, desc) return w - elseif op isa libgb.GrB_BinaryOp + elseif op isa TypedBinaryOperator libgb.GrB_Vector_eWiseMult_BinaryOp(w, mask, accum, op, u, v, desc) return w else @@ -107,13 +107,13 @@ function emul!( A, desc, B = _handletranspose(A, desc, B) op = getoperator(op, optype(A, B)) accum = getoperator(accum, eltype(C)) - if op isa libgb.GrB_Semiring + if op isa TypedSemiring libgb.GrB_Matrix_eWiseMult_Semiring(C, mask, accum, op, A, B, desc) return C - elseif op isa libgb.GrB_Monoid + elseif op isa TypedMonoid libgb.GrB_Matrix_eWiseMult_Monoid(C, mask, accum, op, A, B, desc) return C - elseif op isa libgb.GrB_BinaryOp + elseif op isa TypedBinaryOperator libgb.GrB_Matrix_eWiseMult_BinaryOp(C, mask, accum, op, A, B, desc) return C else @@ -202,13 +202,13 @@ function eadd!( size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) - if op isa libgb.GrB_Semiring + if op isa TypedSemiring libgb.GrB_Vector_eWiseAdd_Semiring(w, mask, accum, op, u, v, desc) return w - elseif op isa libgb.GrB_Monoid + elseif op isa TypedMonoid libgb.GrB_Vector_eWiseAdd_Monoid(w, mask, accum, op, u, v, desc) return w - elseif op isa libgb.GrB_BinaryOp + elseif op isa TypedBinaryOperator libgb.GrB_Vector_eWiseAdd_BinaryOp(w, mask, accum, op, u, v, desc) return w else @@ -245,13 +245,13 @@ function eadd!( A, desc, B = _handletranspose(A, desc, B) op = getoperator(op, optype(A, B)) accum = getoperator(accum, eltype(C)) - if op isa libgb.GrB_Semiring + if op isa TypedSemiring libgb.GrB_Matrix_eWiseAdd_Semiring(C, mask, accum, op, A, B, desc) return C - elseif op isa libgb.GrB_Monoid + elseif op isa TypedMonoid libgb.GrB_Matrix_eWiseAdd_Monoid(C, mask, accum, op, A, B, desc) return C - elseif op isa libgb.GrB_BinaryOp + elseif op isa TypedBinaryOperator libgb.GrB_Matrix_eWiseAdd_BinaryOp(C, mask, accum, op, A, B, desc) return C else diff --git a/src/operations/kronecker.jl b/src/operations/kronecker.jl index a815e25d..ddbd0944 100644 --- a/src/operations/kronecker.jl +++ b/src/operations/kronecker.jl @@ -16,11 +16,11 @@ function LinearAlgebra.kron!( op = getoperator(op, optype(A, B)) A, desc, B = _handletranspose(A, desc, B) accum = getoperator(accum, eltype(C)) - if op isa libgb.GrB_BinaryOp + if op isa TypedBinaryOperator libgb.GxB_kron(C, mask, accum, op, A, B, desc) - elseif op isa libgb.GrB_Monoid + elseif op isa TypedMonoid libgb.GrB_Matrix_kronecker_Monoid(C, mask, accum, op, A, B, desc) - elseif op isa libgb.GrB_Semiring + elseif op isa TypedSemiring libgb.GrB_Matrix_kronecker_Semiring(C, mask, accum, op, A, B, desc) else throw(ArgumentError("$op is not a valid monoid binary op or semiring.")) diff --git a/src/operations/mul.jl b/src/operations/mul.jl index cd258cd9..e5d72a40 100644 --- a/src/operations/mul.jl +++ b/src/operations/mul.jl @@ -14,7 +14,7 @@ function LinearAlgebra.mul!( op = getoperator(op, optype(A, B)) accum = getoperator(accum, eltype(C)) A, desc, B = _handletranspose(A, desc, B) - op isa libgb.GrB_Semiring || throw(ArgumentError("$op is not a valid libgb.GrB_Semiring")) + op isa TypedSemiring || throw(ArgumentError("$op is not a valid TypedSemiring")) libgb.GrB_mxm(C, mask, accum, op, A, B, desc) return C end @@ -34,7 +34,7 @@ function LinearAlgebra.mul!( op = getoperator(op, optype(u, A)) accum = getoperator(accum, eltype(w)) u, desc, A = _handletranspose(u, desc, A) - op isa libgb.GrB_Semiring || throw(ArgumentError("$op is not a valid libgb.GrB_Semiring")) + op isa TypedSemiring || throw(ArgumentError("$op is not a valid TypedSemiring")) libgb.GrB_vxm(w, mask, accum, op, u, A, desc) return w end @@ -54,7 +54,7 @@ function LinearAlgebra.mul!( op = getoperator(op, optype(A, u)) accum = getoperator(accum, eltype(w)) A, desc, u = _handletranspose(A, desc, u) - op isa libgb.GrB_Semiring || throw(ArgumentError("$op is not a valid libgb.GrB_Semiring")) + op isa TypedSemiring || throw(ArgumentError("$op is not a valid TypedSemiring")) libgb.GrB_mxv(w, mask, accum, op, A, u, desc) return w end diff --git a/src/operations/operationutils.jl b/src/operations/operationutils.jl index cbaca53c..afdc79ef 100644 --- a/src/operations/operationutils.jl +++ b/src/operations/operationutils.jl @@ -18,18 +18,18 @@ end optype(::GBArray{T}, ::GBArray{U}) where {T, U} = optype(T, U) -function inferoutputtype(A::GBArray{T}, B::GBArray{U}, op::AbstractOp) where {T, U} - t = optype(A, B) +function inferoutputtype(::GBArray{T}, ::GBArray{U}, op::AbstractOp) where {T, U} + t = optype(T, U) return ztype(op, t) end function inferoutputtype(::GBArray{T}, op::AbstractOp) where {T} return ztype(op, T) end -function inferoutputtype(::GBArray{T}, op) where {T} - return ztype(op) +function inferoutputtype(::GBArray{T}, ::AbstractTypedOp{Z}) where {T, Z} + return Z end -function inferoutputtype(::GBArray{T}, ::GBArray{U}, op) where {T, U} - return ztype(op) +function inferoutputtype(::GBArray{T}, ::GBArray{U}, ::AbstractTypedOp{Z}) where {T, U, Z} + return Z end function _handlectx(ctx, ctxvar, default = nothing) if ctx === nothing || ctx === missing diff --git a/src/operators/binaryops.jl b/src/operators/binaryops.jl index d8a3d4cd..2ea32c62 100644 --- a/src/operators/binaryops.jl +++ b/src/operators/binaryops.jl @@ -1,11 +1,8 @@ baremodule BinaryOps using ..Types + using ..SuiteSparseGraphBLAS: TypedUnaryOperator end -const BinaryUnion = Union{AbstractBinaryOp, libgb.GrB_BinaryOp} - -function _binarynames(name) - -end +const BinaryUnion = Union{AbstractBinaryOp, TypedBinaryOperator} #TODO: Rewrite function _createbinaryops() @@ -78,22 +75,22 @@ function BinaryOp(name) if isGxB(name) || isGrB(name) #Built-in is immutable, no finalizer structquote = quote struct $containername <: AbstractBinaryOp - pointers::Dict{DataType, libgb.GrB_BinaryOp} + typedops::Dict{DataType, TypedBinaryOperator} name::String - $containername() = new(Dict{DataType, libgb.GrB_BinaryOp}(), $name) + $containername() = new(Dict{DataType, TypedBinaryOperator}(), $name) end end else #UDF is mutable for finalizer structquote = quote mutable struct $containername <: AbstractBinaryOp - pointers::Dict{DataType, libgb.GrB_BinaryOp} + typedops::Dict{DataType, TypedBinaryOperator} name::String function $containername() - b = new(Dict{DataType, libgb.GrB_BinaryOp}(), $name) + b = new(Dict{DataType, TypedBinaryOperator}(), $name) function f(binaryop) - for k ∈ keys(binaryop.pointers) - libgb.GrB_BinaryOp_free(Ref(binaryop.pointers[k])) - delete!(binaryop.pointers, k) + for k ∈ keys(binaryop.typedops) + libgb.GrB_BinaryOp_free(Ref(binaryop.typedops[k])) + delete!(binaryop.typedops, k) end end return finalizer(f, b) @@ -127,7 +124,7 @@ function _addbinaryop( opref = Ref{libgb.GrB_BinaryOp}() binaryopfn_C = @cfunction($binaryopfn, Cvoid, (Ptr{T}, Ref{U}, Ref{V})) libgb.GB_BinaryOp_new(opref, binaryopfn_C, ztype, xtype, ytype, op.name) - op.pointers[U] = opref[] + op.typedops[U] = TypedBinaryOperator{xtype, ytype, ztype}(opref[]) return nothing end @@ -346,42 +343,40 @@ function _load(binary::AbstractBinaryOp) ] name = binary.name if name ∈ booleans - binary.pointers[Bool] = load_global(name * "_BOOL") + binary.typedops[Bool] = TypedBinaryOperator(load_global(name * "_BOOL", libgb.GrB_BinaryOp)) end if name ∈ integers - binary.pointers[Int8] = load_global(name * "_INT8") - binary.pointers[Int16] = load_global(name * "_INT16") - binary.pointers[Int32] = load_global(name * "_INT32") - binary.pointers[Int64] = load_global(name * "_INT64") + binary.typedops[Int8] = TypedBinaryOperator(load_global(name * "_INT8", libgb.GrB_BinaryOp)) + binary.typedops[Int16] = TypedBinaryOperator(load_global(name * "_INT16", libgb.GrB_BinaryOp)) + binary.typedops[Int32] = TypedBinaryOperator(load_global(name * "_INT32", libgb.GrB_BinaryOp)) + binary.typedops[Int64] = TypedBinaryOperator(load_global(name * "_INT64", libgb.GrB_BinaryOp)) end if name ∈ unsignedintegers - binary.pointers[UInt8] = load_global(name * "_UINT8") - binary.pointers[UInt16] = load_global(name * "_UINT16") - binary.pointers[UInt32] = load_global(name * "_UINT32") - binary.pointers[UInt64] = load_global(name * "_UINT64") + binary.typedops[UInt8] = TypedBinaryOperator(load_global(name * "_UINT8", libgb.GrB_BinaryOp)) + binary.typedops[UInt16] = TypedBinaryOperator(load_global(name * "_UINT16", libgb.GrB_BinaryOp)) + binary.typedops[UInt32] = TypedBinaryOperator(load_global(name * "_UINT32", libgb.GrB_BinaryOp)) + binary.typedops[UInt64] = TypedBinaryOperator(load_global(name * "_UINT64", libgb.GrB_BinaryOp)) end if name ∈ floats - binary.pointers[Float32] = load_global(name * "_FP32") - binary.pointers[Float64] = load_global(name * "_FP64") + binary.typedops[Float32] = TypedBinaryOperator(load_global(name * "_FP32", libgb.GrB_BinaryOp)) + binary.typedops[Float64] = TypedBinaryOperator(load_global(name * "_FP64", libgb.GrB_BinaryOp)) end if name ∈ positionals - binary.pointers[Any] = load_global(name * "_INT64") + binary.typedops[Any] = TypedBinaryOperator(load_global(name * "_INT64", libgb.GrB_BinaryOp)) end name = "GxB_" * name[5:end] if name ∈ complexes - binary.pointers[ComplexF32] = load_global(name * "_FC32") - binary.pointers[ComplexF64] = load_global(name * "_FC64") + binary.typedops[ComplexF32] = TypedBinaryOperator(load_global(name * "_FC32", libgb.GrB_BinaryOp)) + binary.typedops[ComplexF64] = TypedBinaryOperator(load_global(name * "_FC64", libgb.GrB_BinaryOp)) end end -Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_BinaryOp) = gxbprint(io, u) - -xtype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)]) -ytype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)]) -ztype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)]) +ztype(::TypedBinaryOperator{X, Y, Z}) where {X, Y, Z} = Z +xtype(::TypedBinaryOperator{X, Y, Z}) where {X, Y, Z} = X +ytype(::TypedBinaryOperator{X, Y, Z}) where {X, Y, Z} = Y """ First argument: `f(x::T,y::T)::T = x` diff --git a/src/operators/libgbops.jl b/src/operators/libgbops.jl new file mode 100644 index 00000000..0a6ae15e --- /dev/null +++ b/src/operators/libgbops.jl @@ -0,0 +1,25 @@ +ztype(op::libgb.GrB_UnaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_UnaryOp_ztype(op)]) +xtype(op::libgb.GrB_UnaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_UnaryOp_xtype(op)]) +Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_UnaryOp) = gxbprint(io, u) + +xtype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)]) +ytype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)]) +ztype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)]) +Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_BinaryOp) = gxbprint(io, u) + + +operator(monoid::libgb.GrB_Monoid) = libgb.GxB_Monoid_operator(monoid) +xtype(monoid::libgb.GrB_Monoid) = xtype(operator(monoid)) +ytype(monoid::libgb.GrB_Monoid) = ytype(operator(monoid)) +ztype(monoid::libgb.GrB_Monoid) = ztype(operator(monoid)) +Base.show(io::IO, ::MIME"text/plain", m::libgb.GrB_Monoid) = gxbprint(io, m) + + + + +multiplyop(rig::libgb.GrB_Semiring) = libgb.GxB_Semiring_multiply(rig) +addop(rig::libgb.GrB_Semiring) = libgb.GxB_Semiring_add(rig) +xtype(rig::libgb.GrB_Semiring) = xtype(multiplyop(rig)) +ytype(rig::libgb.GrB_Semiring) = ytype(multiplyop(rig)) +ztype(rig::libgb.GrB_Semiring) = ztype(addop(rig)) +Base.show(io::IO, ::MIME"text/plain", s::libgb.GrB_Semiring) = gxbprint(io, s) diff --git a/src/operators/monoids.jl b/src/operators/monoids.jl index d6dec4f3..52f491e6 100644 --- a/src/operators/monoids.jl +++ b/src/operators/monoids.jl @@ -1,7 +1,7 @@ baremodule Monoids using ..Types end -const MonoidUnion = Union{AbstractMonoid, libgb.GrB_Monoid} +const MonoidUnion = Union{AbstractMonoid, TypedMonoid} function _monoidnames(name) if isGxB(name) || isGrB(name) @@ -42,22 +42,22 @@ function Monoid(name) if isGxB(name) || isGrB(name) #Built-ins are immutable structquote = quote struct $containername <: AbstractMonoid - pointers::Dict{DataType, libgb.GrB_Monoid} + typedops::Dict{DataType, TypedMonoid} name::String - $containername() = new(Dict{DataType, libgb.GrB_Monoid}(), $name) + $containername() = new(Dict{DataType, TypedMonoid}(), $name) end end else #UDFs are mutable for finalizing structquote = quote mutable struct $containername <: AbstractMonoid - pointers::Dict{DataType, libgb.GrB_Monoid} + typedops::Dict{DataType, TypedMonoid} name::String function $containername() - m = new(Dict{DataType, libgb.GrB_Monoid}(), $name) + m = new(Dict{DataType, TypedMonoid}(), $name) function f(monoid) - for k ∈ keys(monoid.pointers) - libgb.GrB_Monoid_free(Ref(monoid.pointers[k])) - delete!(monoid.pointers, k) + for k ∈ keys(monoid.typedops) + libgb.GrB_Monoid_free(Ref(monoid.typedops[k])) + delete!(monoid.typedops, k) end end return finalizer(f, m) @@ -95,7 +95,7 @@ function _addmonoid(op::AbstractMonoid, binop::BinaryUnion, id::T, terminal = no else libgb.monoidtermnew[Any](monref, binop, Ptr{Cvoid}(id), Ptr{Cvoid}(terminal)) end - op.pointers[T] = monref[] + op.typedops[T] = TypedMonoid{xtype(binop), ytype(binop), ztype(binop)}(monref[]) return nothing end @@ -142,49 +142,49 @@ function _load(monoid::AbstractMonoid) if name ∈ booleans constname = name * ((isGxB(name) ? "_BOOL_MONOID" : "_MONOID_BOOL")) - monoid.pointers[Bool] = load_global(constname) + monoid.typedops[Bool] = TypedMonoid(load_global(constname, libgb.GrB_Monoid)) end if name ∈ integers - monoid.pointers[Int8] = - load_global(name * (isGxB(name) ? "_INT8_MONOID" : "_MONOID_INT8")) - monoid.pointers[Int16] = - load_global(name * (isGxB(name) ? "_INT16_MONOID" : "_MONOID_INT16")) - monoid.pointers[Int32] = - load_global(name * (isGxB(name) ? "_INT32_MONOID" : "_MONOID_INT32")) - monoid.pointers[Int64] = - load_global(name * (isGxB(name) ? "_INT64_MONOID" : "_MONOID_INT64")) + monoid.typedops[Int8] = + TypedMonoid(load_global(name * (isGxB(name) ? "_INT8_MONOID" : "_MONOID_INT8"), libgb.GrB_Monoid)) + monoid.typedops[Int16] = + TypedMonoid(load_global(name * (isGxB(name) ? "_INT16_MONOID" : "_MONOID_INT16"), libgb.GrB_Monoid)) + monoid.typedops[Int32] = + TypedMonoid(load_global(name * (isGxB(name) ? "_INT32_MONOID" : "_MONOID_INT32"), libgb.GrB_Monoid)) + monoid.typedops[Int64] = + TypedMonoid(load_global(name * (isGxB(name) ? "_INT64_MONOID" : "_MONOID_INT64"), libgb.GrB_Monoid)) end if name ∈ unsignedintegers - monoid.pointers[UInt8] = - load_global(name * (isGxB(name) ? "_UINT8_MONOID" : "_MONOID_UINT8")) - monoid.pointers[UInt16] = - load_global(name * (isGxB(name) ? "_UINT16_MONOID" : "_MONOID_UINT16")) - monoid.pointers[UInt32] = - load_global(name * (isGxB(name) ? "_UINT32_MONOID" : "_MONOID_UINT32")) - monoid.pointers[UInt64] = - load_global(name * (isGxB(name) ? "_UINT64_MONOID" : "_MONOID_UINT64")) + monoid.typedops[UInt8] = + TypedMonoid(load_global(name * (isGxB(name) ? "_UINT8_MONOID" : "_MONOID_UINT8"), libgb.GrB_Monoid)) + monoid.typedops[UInt16] = + TypedMonoid(load_global(name * (isGxB(name) ? "_UINT16_MONOID" : "_MONOID_UINT16"), libgb.GrB_Monoid)) + monoid.typedops[UInt32] = + TypedMonoid(load_global(name * (isGxB(name) ? "_UINT32_MONOID" : "_MONOID_UINT32"), libgb.GrB_Monoid)) + monoid.typedops[UInt64] = + TypedMonoid(load_global(name * (isGxB(name) ? "_UINT64_MONOID" : "_MONOID_UINT64"), libgb.GrB_Monoid)) end if name ∈ floats - monoid.pointers[Float32] = - load_global(name * (isGxB(name) ? "_FP32_MONOID" : "_MONOID_FP32")) - monoid.pointers[Float64] = - load_global(name * (isGxB(name) ? "_FP64_MONOID" : "_MONOID_FP64")) + monoid.typedops[Float32] = + TypedMonoid(load_global(name * (isGxB(name) ? "_FP32_MONOID" : "_MONOID_FP32"), libgb.GrB_Monoid)) + monoid.typedops[Float64] = + TypedMonoid(load_global(name * (isGxB(name) ? "_FP64_MONOID" : "_MONOID_FP64"), libgb.GrB_Monoid)) end name = "GxB_" * name[5:end] if name ∈ complexes #Complex monoids are always GxB, so "_MONOID" is always at the end. - monoid.pointers[ComplexF32] = load_global(name * "_FC32_MONOID") - monoid.pointers[ComplexF64] = load_global(name * "_FC64_MONOID") + monoid.typedops[ComplexF32] = TypedMonoid(load_global(name * "_FC32_MONOID", libgb.GrB_Monoid)) + monoid.typedops[ComplexF64] = TypedMonoid(load_global(name * "_FC64_MONOID", libgb.GrB_Monoid)) end end -Base.show(io::IO, ::MIME"text/plain", m::libgb.GrB_Monoid) = gxbprint(io, m) -operator(monoid::libgb.GrB_Monoid) = libgb.GxB_Monoid_operator(monoid) -xtype(monoid::libgb.GrB_Monoid) = xtype(operator(monoid)) -ytype(monoid::libgb.GrB_Monoid) = ytype(operator(monoid)) -ztype(monoid::libgb.GrB_Monoid) = ztype(operator(monoid)) + + +ztype(::TypedMonoid{X, Y, Z}) where {X, Y, Z} = Z +xtype(::TypedMonoid{X, Y, Z}) where {X, Y, Z} = X +ytype(::TypedMonoid{X, Y, Z}) where {X, Y, Z} = Y """ Minimum monoid: `f(x::ℝ, y::ℝ)::ℝ = min(x, y)` diff --git a/src/operators/operatorutils.jl b/src/operators/operatorutils.jl index 075e9d82..982d8b3a 100644 --- a/src/operators/operatorutils.jl +++ b/src/operators/operatorutils.jl @@ -33,7 +33,7 @@ function getoperator(op, t) end end -_isloaded(o::AbstractOp) = !isempty(o.pointers) +_isloaded(o::AbstractOp) = !isempty(o.typedops) """ validtypes(operator::AbstractOp)::Vector{DataType} @@ -50,15 +50,15 @@ function validtypes(o::AbstractOp) if !_isloaded(o) _load(o) end - return collect(keys(o.pointers)) + return collect(keys(o.typedops)) end function Base.getindex(o::AbstractOp, t::DataType) _isloaded(o) || _load(o) - if Any ∈ keys(o.pointers) - getindex(o.pointers, Any) + if Any ∈ keys(o.typedops) + getindex(o.typedops, Any) else - getindex(o.pointers, t) + getindex(o.typedops, t) end end diff --git a/src/operators/semirings.jl b/src/operators/semirings.jl index a06c9881..64be6cea 100644 --- a/src/operators/semirings.jl +++ b/src/operators/semirings.jl @@ -2,7 +2,7 @@ baremodule Semirings using ..Types end -SemiringUnion = Union{AbstractSemiring, libgb.GrB_Semiring} +SemiringUnion = Union{AbstractSemiring, TypedSemiring} function _semiringnames(name) if isGxB(name) || isGrB(name) @@ -242,22 +242,22 @@ function Semiring(name) if isGxB(name) || isGrB(name) structquote = quote struct $containername <: AbstractSemiring - pointers::Dict{DataType, libgb.GrB_Semiring} + typedops::Dict{DataType, TypedSemiring} name::String - $containername() = new(Dict{DataType, libgb.GrB_Semiring}(), $name) + $containername() = new(Dict{DataType, TypedSemiring}(), $name) end end else structquote = quote mutable struct $containername <: AbstractSemiring - pointers::Dict{DataType, libgb.GrB_Semiring} + typedops::Dict{DataType, TypedSemiring} name::String function $containername() - r = new(Dict{DataType, libgb.GrB_Semiring}(), $name) + r = new(Dict{DataType, TypedSemiring}(), $name) function f(rig) - for k ∈ keys(rig.pointers) - libgb.GrB_Semiring_free(Ref(rig.pointers[k])) - delete!(rig.pointers, k) + for k ∈ keys(rig.typedops) + libgb.GrB_Semiring_free(Ref(rig.typedops[k])) + delete!(rig.typedops, k) end end return finalizer(f, r) @@ -275,15 +275,15 @@ function Semiring(name) end #Add typed ⊕ and ⊗ to semiring -function _addsemiring(rig::AbstractSemiring, add::libgb.GrB_Monoid, mul::libgb.GrB_BinaryOp) - rigref = Ref{libgb.GrB_Semiring}() +function _addsemiring(rig::AbstractSemiring, add::TypedMonoid, mul::TypedBinaryOperator) + rigref = Ref{TypedSemiring}() libgb.GrB_Semiring_new(rigref, add, mul) - rig.pointers[xtype(add)] = rigref[] + rig.typedops[xtype(add)] = TypedSemiring(rigref[]) return nothing end #New semiring with typed ⊕ and ⊗ -function Semiring(name::String, add::libgb.GrB_Monoid, mul::libgb.GrB_BinaryOp) +function Semiring(name::String, add::TypedMonoid, mul::TypedBinaryOperator) rig = Semiring(name) _addsemiring(rig, add, mul) return rig @@ -301,6 +301,10 @@ function Semiring(name::String, add::AbstractMonoid, mul::AbstractBinaryOp) return rig end +ztype(::TypedSemiring{X, Y, Z}) where {X, Y, Z} = Z +xtype(::TypedSemiring{X, Y, Z}) where {X, Y, Z} = X +ytype(::TypedSemiring{X, Y, Z}) where {X, Y, Z} = Y + function _load(rig::AbstractSemiring) booleans = ["GxB_LOR_FIRST", "GxB_LAND_FIRST", @@ -846,43 +850,34 @@ function _load(rig::AbstractSemiring) ] name = rig.name if name ∈ booleans - rig.pointers[Bool] = load_global(name * "_BOOL") + rig.typedops[Bool] = TypedSemiring(load_global(name * "_BOOL", libgb.GrB_Semiring)) end if name ∈ integers - rig.pointers[Int8] =load_global(name * "_INT8") - rig.pointers[Int16] = load_global(name * "_INT16") - rig.pointers[Int32] = load_global(name * "_INT32") - rig.pointers[Int64] = load_global(name * "_INT64") + rig.typedops[Int8] = TypedSemiring(load_global(name * "_INT8", libgb.GrB_Semiring)) + rig.typedops[Int16] = TypedSemiring(load_global(name * "_INT16", libgb.GrB_Semiring)) + rig.typedops[Int32] = TypedSemiring(load_global(name * "_INT32", libgb.GrB_Semiring)) + rig.typedops[Int64] = TypedSemiring(load_global(name * "_INT64", libgb.GrB_Semiring)) end if name ∈ unsignedintegers - rig.pointers[UInt8] =load_global(name * "_UINT8") - rig.pointers[UInt16] = load_global(name * "_UINT16") - rig.pointers[UInt32] = load_global(name * "_UINT32") - rig.pointers[UInt64] = load_global(name * "_UINT64") + rig.typedops[UInt8] = TypedSemiring(load_global(name * "_UINT8", libgb.GrB_Semiring)) + rig.typedops[UInt16] = TypedSemiring(load_global(name * "_UINT16", libgb.GrB_Semiring)) + rig.typedops[UInt32] = TypedSemiring(load_global(name * "_UINT32", libgb.GrB_Semiring)) + rig.typedops[UInt64] = TypedSemiring(load_global(name * "_UINT64", libgb.GrB_Semiring)) end if name ∈ floats - rig.pointers[Float32] = load_global(name * "_FP32") - rig.pointers[Float64] = load_global(name * "_FP64") + rig.typedops[Float32] = TypedSemiring(load_global(name * "_FP32", libgb.GrB_Semiring)) + rig.typedops[Float64] = TypedSemiring(load_global(name * "_FP64", libgb.GrB_Semiring)) end if name ∈ positionals - rig.pointers[Any] = load_global(name * "_INT64") + rig.typedops[Any] = TypedSemiring(load_global(name * "_INT64", libgb.GrB_Semiring)) end name = replace(name, "GrB_" => "GxB_") name = replace(name, "_SEMIRING" => "") if name ∈ complexes - rig.pointers[ComplexF32] = load_global(name * "_FC32") - rig.pointers[ComplexF64] = load_global(name * "_FC64") + rig.typedops[ComplexF32] = TypedSemiring(load_global(name * "_FC32", libgb.GrB_Semiring)) + rig.typedops[ComplexF64] = TypedSemiring(load_global(name * "_FC64", libgb.GrB_Semiring)) end end - -Base.show(io::IO, ::MIME"text/plain", s::libgb.GrB_Semiring) = gxbprint(io, s) - -multiplyop(rig::libgb.GrB_Semiring) = libgb.GxB_Semiring_multiply(rig) -addop(rig::libgb.GrB_Semiring) = libgb.GxB_Semiring_add(rig) - -xtype(rig::libgb.GrB_Semiring) = xtype(multiplyop(rig)) -ytype(rig::libgb.GrB_Semiring) = ytype(multiplyop(rig)) -ztype(rig::libgb.GrB_Semiring) = ztype(addop(rig)) diff --git a/src/operators/unaryops.jl b/src/operators/unaryops.jl index 09386c0f..2b879da4 100644 --- a/src/operators/unaryops.jl +++ b/src/operators/unaryops.jl @@ -1,9 +1,10 @@ baremodule UnaryOps using ..Types + using ..SuiteSparseGraphBLAS: TypedUnaryOperator end -const UnaryUnion = Union{AbstractUnaryOp, libgb.GrB_UnaryOp} +const UnaryUnion = Union{AbstractUnaryOp, TypedUnaryOperator} #TODO: Rewrite function _createunaryops() @@ -75,22 +76,22 @@ function UnaryOp(name) if isGxB(name) || isGrB(name) structquote = quote struct $tname <: AbstractUnaryOp - pointers::Dict{DataType, libgb.GrB_UnaryOp} + typedops::Dict{DataType, TypedUnaryOperator} name::String - $tname() = new(Dict{DataType, libgb.GrB_UnaryOp}(), $name) + $tname() = new(Dict{DataType, TypedUnaryOperator}(), $name) end end else #If it's a UDF we need a mutable for finalizing purposes. structquote = quote mutable struct $tname <: AbstractUnaryOp - pointers::Dict{DataType, libgb.GrB_UnaryOp} + typedops::Dict{DataType, TypedUnaryOperator} name::String function $tname() - u = new(Dict{DataType, libgb.GrB_UnaryOp}(), $name) + u = new(Dict{DataType, TypedUnaryOperator}(), $name) function f(unaryop) - for k ∈ keys(unaryop.pointers) - libgb.GrB_UnaryOp_free(Ref(unaryop.pointers[k])) - delete!(unaryop.pointers, k) + for k ∈ keys(unaryop.typedops) + libgb.GrB_UnaryOp_free(Ref(unaryop.typedops[k])) + delete!(unaryop.typedops, k) end end return finalizer(f, u) @@ -117,7 +118,7 @@ function _addunaryop(op::AbstractUnaryOp, fn::Function, ztype::GBType{T}, xtype: opref = Ref{libgb.GrB_UnaryOp}() unaryopfn_C = @cfunction($unaryopfn, Cvoid, (Ptr{T}, Ref{U})) libgb.GB_UnaryOp_new(opref, unaryopfn_C, ztype, xtype, op.name) - op.pointers[U] = opref[] + op.typedops[U] = TypedUnaryOperator{xtype, ztype}(opref[]) return nothing end @@ -146,10 +147,11 @@ end function UnaryOp(name::String, fn::Function, type::Vector{DataType}) return UnaryOp(name, fn, type, type) end -#Construct it using the built in primitives. +#Construct it using all the built in primitives. function UnaryOp(name::String, fn::Function) return UnaryOp(name, fn, valid_vec) end + function _load(unaryop::AbstractUnaryOp) booleans = ["GrB_IDENTITY", "GrB_AINV", "GrB_MINV", "GxB_LNOT", "GxB_ONE", "GrB_ABS"] integers = [ @@ -255,41 +257,40 @@ function _load(unaryop::AbstractUnaryOp) name = unaryop.name if name ∈ booleans constname = name * "_BOOL" - unaryop.pointers[Bool] = load_global(constname) + unaryop.typedops[Bool] = TypedUnaryOperator(load_global(constname, libgb.GrB_UnaryOp)) end if name ∈ integers - unaryop.pointers[Int8] = load_global(name * "_INT8") - unaryop.pointers[Int16] = load_global(name * "_INT16") - unaryop.pointers[Int32] = load_global(name * "_INT32") - unaryop.pointers[Int64] = load_global(name * "_INT64") + unaryop.typedops[Int8] = TypedUnaryOperator(load_global(name * "_INT8", libgb.GrB_UnaryOp)) + unaryop.typedops[Int16] = TypedUnaryOperator(load_global(name * "_INT16", libgb.GrB_UnaryOp)) + unaryop.typedops[Int32] = TypedUnaryOperator(load_global(name * "_INT32", libgb.GrB_UnaryOp)) + unaryop.typedops[Int64] = TypedUnaryOperator(load_global(name * "_INT64", libgb.GrB_UnaryOp)) end if name ∈ unsignedintegers - unaryop.pointers[UInt8] = load_global(name * "_UINT8") - unaryop.pointers[UInt16] = load_global(name * "_UINT16") - unaryop.pointers[UInt32] = load_global(name * "_UINT32") - unaryop.pointers[UInt64] = load_global(name * "_UINT64") + unaryop.typedops[UInt8] = TypedUnaryOperator(load_global(name * "_UINT8", libgb.GrB_UnaryOp)) + unaryop.typedops[UInt16] = TypedUnaryOperator(load_global(name * "_UINT16", libgb.GrB_UnaryOp)) + unaryop.typedops[UInt32] = TypedUnaryOperator(load_global(name * "_UINT32", libgb.GrB_UnaryOp)) + unaryop.typedops[UInt64] = TypedUnaryOperator(load_global(name * "_UINT64", libgb.GrB_UnaryOp)) end if name ∈ floats - unaryop.pointers[Float32] = load_global(name * "_FP32") - unaryop.pointers[Float64] = load_global(name * "_FP64") + unaryop.typedops[Float32] = TypedUnaryOperator(load_global(name * "_FP32", libgb.GrB_UnaryOp)) + unaryop.typedops[Float64] = TypedUnaryOperator(load_global(name * "_FP64", libgb.GrB_UnaryOp)) end if name ∈ positionals - unaryop.pointers[Any] = load_global(name * "_INT64") + unaryop.typedops[Any] = TypedUnaryOperator(load_global(name * "_INT64", libgb.GrB_UnaryOp)) end name = "GxB_" * name[5:end] if name ∈ complexes - unaryop.pointers[ComplexF32] = load_global(name * "_FC32") - unaryop.pointers[ComplexF64] = load_global(name * "_FC64") + unaryop.typedops[ComplexF32] = TypedUnaryOperator(load_global(name * "_FC32", libgb.GrB_UnaryOp)) + unaryop.typedops[ComplexF64] = TypedUnaryOperator(load_global(name * "_FC64", libgb.GrB_UnaryOp)) end end -ztype(op::libgb.GrB_UnaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_UnaryOp_ztype(op)]) -xtype(op::libgb.GrB_UnaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_UnaryOp_xtype(op)]) +ztype(::TypedUnaryOperator{I, O}) where {I, O} = O +xtype(::TypedUnaryOperator{I, O}) where {I, O} = I -Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_UnaryOp) = gxbprint(io, u) """ Identity: `z=x` diff --git a/src/types.jl b/src/types.jl index 66b20f9a..9094dc0a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,3 +1,35 @@ +struct TypedUnaryOperator{X, Z} <: AbstractTypedOp{Z} + p::libgb.GrB_UnaryOp +end +function TypedUnaryOperator(p::libgb.GrB_UnaryOp) + return TypedUnaryOperator{xtype(p), ztype(p)}(p) +end +Base.unsafe_convert(::Type{libgb.GrB_UnaryOp}, op::TypedUnaryOperator) = op.p + +struct TypedBinaryOperator{X, Y, Z} <: AbstractTypedOp{Z} + p::libgb.GrB_BinaryOp +end +function TypedBinaryOperator(p::libgb.GrB_BinaryOp) + return TypedBinaryOperator{xtype(p), ytype(p), ztype(p)}(p) +end +Base.unsafe_convert(::Type{libgb.GrB_BinaryOp}, op::TypedBinaryOperator) = op.p + +struct TypedMonoid{X, Y, Z} <: AbstractTypedOp{Z} + p::libgb.GrB_Monoid +end +function TypedMonoid(p::libgb.GrB_Monoid) + return TypedMonoid{xtype(p), ytype(p), ztype(p)}(p) +end +Base.unsafe_convert(::Type{libgb.GrB_Monoid}, op::TypedMonoid) = op.p + +struct TypedSemiring{X, Y, Z} <: AbstractTypedOp{Z} + p::libgb.GrB_Semiring +end +function TypedSemiring(p::libgb.GrB_Semiring) + return TypedSemiring{xtype(p), ytype(p), ztype(p)}(p) +end +Base.unsafe_convert(::Type{libgb.GrB_Semiring}, op::TypedSemiring) = op.p + """ Automatically generated type definitions. The struct definitions for built in monoids, binary ops, etc can be found here. @@ -5,6 +37,8 @@ built in monoids, binary ops, etc can be found here. module Types import ...SuiteSparseGraphBLAS: AbstractUnaryOp, AbstractMonoid, AbstractSelectOp, AbstractSemiring, AbstractBinaryOp, AbstractDescriptor + using ...SuiteSparseGraphBLAS: TypedUnaryOperator, TypedBinaryOperator, TypedMonoid, + TypedSemiring using ..libgb end diff --git a/test/chainrules/mulrules.jl b/test/chainrules/mulrules.jl index ebd7e165..e522a61f 100644 --- a/test/chainrules/mulrules.jl +++ b/test/chainrules/mulrules.jl @@ -3,19 +3,19 @@ @testset "Arithmetic Semiring" begin M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_frule(mul, M, Y) + test_frule(mul, M, Y, Semirings.PLUS_TIMES) + test_rrule(mul, M, Y) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES) end end @testset "Sparse" begin M = GBMatrix(sprand(100, 10, 0.25)) Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_frule(mul, M, Y) + test_frule(mul, M, Y, Semirings.PLUS_TIMES) + test_rrule(mul, M, Y) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES) end end From 4a9a3a6b9482fb183c6aa265de0c1273212c8734 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sun, 11 Jul 2021 14:48:22 -0400 Subject: [PATCH 2/3] Fix type inference, remove with/context, test toml --- Project.toml | 9 +------ src/SuiteSparseGraphBLAS.jl | 8 ------ src/chainrules/mulrules.jl | 8 +++--- src/operations/ewise.jl | 40 ++++++++++-------------------- src/operations/kronecker.jl | 7 +++--- src/operations/map.jl | 26 ++++++++----------- src/operations/mul.jl | 20 +++++++-------- src/operations/operationutils.jl | 24 ++---------------- src/operations/reduce.jl | 6 ++--- src/operations/select.jl | 6 ++--- src/operations/transpose.jl | 7 ++++-- src/with.jl | 16 ------------ test/Project.toml | 11 ++++++++ test/chainrules/chainrulesutils.jl | 1 - test/chainrules/ewiserules.jl | 32 ++++++++++++------------ test/runtests.jl | 2 +- 16 files changed, 82 insertions(+), 141 deletions(-) delete mode 100644 src/with.jl create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 5bc3e979..b4c0b81e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,23 +6,16 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSGraphBLAS_jll = "7ed9a814-9cab-54e9-8e9e-d9e95b4d61b1" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] CEnum = "0.4" -ContextVariablesX = "0.1" +ChainRulesCore = "0.10" MacroTools = "0.5" SSGraphBLAS_jll = "5.1" julia = "1.6" -ChainRulesCore = "0.10" -ChainRulesTestUtils = "0.7" -FiniteDifferences = "0.12" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index 40ed763a..0a82c843 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -7,7 +7,6 @@ using MacroTools using LinearAlgebra using Random: randsubseq, default_rng, AbstractRNG, GLOBAL_RNG using CEnum -using ContextVariablesX include("abstracts.jl") include("libutils.jl") include("lib/LibGraphBLAS.jl") @@ -70,13 +69,6 @@ const OperatorUnion = Union{ GrBOp } -#Context variables for the `with` function -@contextvar ctxop::OperatorUnion -@contextvar ctxmask::Union{GBArray, Ptr} = C_NULL -@contextvar ctxaccum::Union{BinaryUnion, Ptr} = C_NULL -@contextvar ctxdesc::Descriptor -include("with.jl") - include("scalar.jl") include("vector.jl") include("matrix.jl") diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl index 3344f05d..3c9a4bd9 100644 --- a/src/chainrules/mulrules.jl +++ b/src/chainrules/mulrules.jl @@ -14,8 +14,8 @@ function frule( B::GBMatOrTranspose, ::typeof(Semirings.PLUS_TIMES) ) - Ω = mul(A, B) - ∂Ω = mul(ΔA, B) + mul(A, ΔB) + Ω = mul(A, B, Semirings.PLUS_TIMES) + ∂Ω = mul(ΔA, B, Semirings.PLUS_TIMES) + mul(A, ΔB, Semirings.PLUS_TIMES) return Ω, ∂Ω end # Tests will not pass for this. For two reasons. @@ -32,8 +32,8 @@ function rrule( ::typeof(Semirings.PLUS_TIMES) ) function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B'; mask=A) - ∂B = mul(A', ΔΩ; mask=B) + ∂A = mul(ΔΩ, B', Semirings.PLUS_TIMES; mask=A) + ∂B = mul(A', ΔΩ, Semirings.PLUS_TIMES; mask=B) return NoTangent(), ∂A, ∂B, NoTangent() end return mul(A, B), mulpullback diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index f2c9b89d..4c6ce675 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -55,12 +55,12 @@ function emul!( w::GBVector, u::GBVector, v::GBVector, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) @@ -82,12 +82,11 @@ end function emul( u::GBVector, v::GBVector, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op = _handlectx(op, ctxop, BinaryOps.TIMES) t = inferoutputtype(u, v, op) w = GBVector{t}(size(u)) return emul!(w, u, v, op; mask , accum, desc) @@ -97,12 +96,12 @@ function emul!( C::GBMatrix, A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) size(C) == size(A) == size(B) || throw(DimensionMismatch()) A, desc, B = _handletranspose(A, desc, B) op = getoperator(op, optype(A, B)) @@ -125,12 +124,11 @@ end function emul( A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op = _handlectx(op, ctxop, BinaryOps.TIMES) t = inferoutputtype(A, B, op) C = GBMatrix{t}(size(A)) return emul!(C, A, B, op; mask, accum, desc) @@ -193,12 +191,12 @@ function eadd!( w::GBVector, u::GBVector, v::GBVector, - op = nothing; + op = BinaryOps.PLUS; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS) + mask, accum, desc = _handlenothings(mask, accum, desc) size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) @@ -220,12 +218,11 @@ end function eadd( u::GBVector, v::GBVector, - op = nothing; + op = BinaryOps.PLUS; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS) t = inferoutputtype(u, v, op) w = GBVector{t}(size(u)) return eadd!(w, u, v, op; mask, accum, desc) @@ -235,12 +232,12 @@ function eadd!( C::GBMatrix, A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.PLUS; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS) + mask, accum, desc = _handlenothings(mask, accum, desc) size(C) == size(A) == size(B) || throw(DimensionMismatch()) A, desc, B = _handletranspose(A, desc, B) op = getoperator(op, optype(A, B)) @@ -263,19 +260,18 @@ end function eadd( A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.PLUS; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS) t = inferoutputtype(A, B, op) C = GBMatrix{t}(size(A)) return eadd!(C, A, B, op; mask, accum, desc) end function Base.:+(A::GBArray, B::GBArray) - eadd(A, B, nothing) + eadd(A, B, BinaryOps.PLUS) end function Base.:-(A::GBArray, B::GBArray) @@ -284,16 +280,6 @@ end #Elementwise Broadcasts ####################### -# default argument is missing to avoid `nothing` picking up the default default :). -function Base.broadcasted(::typeof(∪), A::GBArray, B::GBArray) - eadd(A, B, missing) -end - -# default argument is missing to avoid `nothing` picking up the default default :). -function Base.broadcasted(::typeof(∩), A::GBArray, B::GBArray) - emul(A, B, missing) -end - function Base.broadcasted(::typeof(*), A::GBArray, B::GBArray) emul(A, B, BinaryOps.TIMES) end diff --git a/src/operations/kronecker.jl b/src/operations/kronecker.jl index ddbd0944..1e305db9 100644 --- a/src/operations/kronecker.jl +++ b/src/operations/kronecker.jl @@ -7,12 +7,12 @@ function LinearAlgebra.kron!( C::GBMatOrTranspose, A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) op = getoperator(op, optype(A, B)) A, desc, B = _handletranspose(A, desc, B) accum = getoperator(accum, eltype(C)) @@ -47,12 +47,11 @@ Does not support `GBVector`s at this time. function LinearAlgebra.kron( A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = BinaryOps.TIMES; mask = nothing, accum = nothing, desc = nothing ) - op = _handlectx(op, ctxop, BinaryOps.TIMES) t = inferoutputtype(A, B, op) C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2)) kron!(C, A, B, op; mask, accum, desc) diff --git a/src/operations/map.jl b/src/operations/map.jl index 091d578a..3914c3d2 100644 --- a/src/operations/map.jl +++ b/src/operations/map.jl @@ -4,7 +4,7 @@ function Base.map!( op::UnaryUnion, C::GBArray, A::GBArray; mask = nothing, accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) op = getoperator(op, eltype(A)) accum = getoperator(accum, eltype(C)) if C isa GBVector && A isa GBVector @@ -33,7 +33,7 @@ function Base.map!( op::BinaryUnion, C::GBArray, x, A::GBArray; mask = nothing, accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) op = getoperator(op, optype(eltype(A), typeof(x))) accum = getoperator(accum, eltype(C)) if C isa GBVector && A isa GBVector @@ -63,7 +63,7 @@ function Base.map!( op::BinaryUnion, C::GBArray, A::GBArray, x; mask = nothing, accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) op = getoperator(op, optype(eltype(A), typeof(x))) accum = getoperator(accum, eltype(C)) if C isa GBVector && A isa GBVector @@ -89,27 +89,23 @@ function Base.map( return map!(op, similar(A, t), A, x; mask, accum, desc) end -function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union; - mask = nothing, accum = nothing, desc = nothing +function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union ) - map(BinaryOps.PLUS, u, x; mask, accum, desc) + map(BinaryOps.PLUS, u, x) end function Base.broadcasted( - ::typeof(+), x::valid_union, u::GBArray; - mask = nothing, accum = nothing, desc = nothing + ::typeof(+), x::valid_union, u::GBArray ) - map(BinaryOps.PLUS, x, u; mask, accum, desc) + map(BinaryOps.PLUS, x, u) end -function Base.broadcasted(::typeof(*), u::GBArray, x::valid_union; - mask = nothing, accum = nothing, desc = nothing +function Base.broadcasted(::typeof(*), u::GBArray, x::valid_union ) - map(BinaryOps.TIMES, u, x; mask, accum, desc) + map(BinaryOps.TIMES, u, x) end -function Base.broadcasted(::typeof(*), x::valid_union, u::GBArray; - mask = nothing, accum = nothing, desc = nothing +function Base.broadcasted(::typeof(*), x::valid_union, u::GBArray ) - map(BinaryOps.TIMES, x, u; mask, accum, desc) + map(BinaryOps.TIMES, x, u) end """ diff --git a/src/operations/mul.jl b/src/operations/mul.jl index e5d72a40..69763f54 100644 --- a/src/operations/mul.jl +++ b/src/operations/mul.jl @@ -2,12 +2,12 @@ function LinearAlgebra.mul!( C::GBMatrix, A::GBMatOrTranspose, B::GBMatOrTranspose, - op = nothing; + op = Semirings.PLUS_TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, Semirings.PLUS_TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) size(A, 2) == size(B, 1) || throw(DimensionMismatch("size(A, 2) != size(B, 1)")) size(A, 1) == size(C, 1) || throw(DimensionMismatch("size(A, 1) != size(C, 1)")) size(B, 2) == size(C, 2) || throw(DimensionMismatch("size(B, 2) != size(C, 2)")) @@ -23,12 +23,12 @@ function LinearAlgebra.mul!( w::GBVector, u::GBVector, A::GBMatOrTranspose, - op = nothing; + op = Semirings.PLUS_TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, Semirings.PLUS_TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) size(u, 1) == size(A, 1) || throw(DimensionMismatch("size(A, 1) != size(u)")) size(w, 1) == size(A, 2) || throw(DimensionMismatch("size(A, 2) != size(w)")) op = getoperator(op, optype(u, A)) @@ -43,12 +43,12 @@ function LinearAlgebra.mul!( w::GBVector, A::GBMatOrTranspose, u::GBVector, - op = nothing; + op = Semirings.PLUS_TIMES; mask = nothing, accum = nothing, desc = nothing ) - op, mask, accum, desc = _handlectx(op, mask, accum, desc, Semirings.PLUS_TIMES) + mask, accum, desc = _handlenothings(mask, accum, desc) size(u, 1) == size(A, 2) || throw(DimensionMismatch("size(A, 2) != size(u)")) size(w, 1) == size(A, 1) || throw(DimensionMismatch("size(A, 1) != size(w")) op = getoperator(op, optype(A, u)) @@ -86,12 +86,11 @@ The default semiring is the `+.*` semiring. function mul( A::GBArray, B::GBArray, - op = nothing; + op = Semirings.PLUS_TIMES; mask = nothing, accum = nothing, desc = nothing ) - op = _handlectx(op, ctxop, Semirings.PLUS_TIMES) t = inferoutputtype(A, B, op) if A isa GBVector && B isa GBMatOrTranspose C = GBVector{t}(size(B, 2)) @@ -106,11 +105,10 @@ end function Base.:*( A::GBArray, - B::GBArray, - op = nothing; + B::GBArray; mask = nothing, accum = nothing, desc = nothing ) - mul(A, B, op; mask, accum, desc) + mul(A, B, Semirings.PLUS_TIMES; mask, accum, desc) end diff --git a/src/operations/operationutils.jl b/src/operations/operationutils.jl index afdc79ef..258c4ed6 100644 --- a/src/operations/operationutils.jl +++ b/src/operations/operationutils.jl @@ -31,30 +31,10 @@ end function inferoutputtype(::GBArray{T}, ::GBArray{U}, ::AbstractTypedOp{Z}) where {T, U, Z} return Z end -function _handlectx(ctx, ctxvar, default = nothing) - if ctx === nothing || ctx === missing - ctx2 = get(ctxvar) - if ctx2 !== nothing - return something(ctx2) - elseif ctx !== missing - return default - else - throw(ArgumentError("This operation requires an operator specified by the `with` function.")) - end - else - return ctx - end -end -function _handlectx(op, mask, accum, desc, defaultop = nothing) - return ( - _handlectx(op, ctxop, defaultop), - _handlectx(mask, ctxmask, C_NULL), - _handlectx(accum, ctxaccum, C_NULL), - _handlectx(desc, ctxdesc, Descriptors.NULL) - ) +function _handlenothings(kwargs...) + return (x === nothing ? C_NULL : x for x in kwargs) end - """ xtype(op::GrBOp)::DataType diff --git a/src/operations/reduce.jl b/src/operations/reduce.jl index 8d6d8c44..cfc47566 100644 --- a/src/operations/reduce.jl +++ b/src/operations/reduce.jl @@ -2,7 +2,7 @@ function reduce!( op::MonoidUnion, w::GBVector, A::GBMatOrTranspose; mask = nothing, accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) A, desc, _ = _handletranspose(A, desc, nothing) op = getoperator(op, eltype(w)) accum = getoperator(accum, eltype(w)) @@ -20,7 +20,7 @@ function Base.reduce( accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) if typeout === nothing typeout = eltype(A) end @@ -57,7 +57,7 @@ function Base.reduce( accum = nothing, desc = nothing ) - _, _, accum, desc = _handlectx(op, nothing, accum, desc) + accum, desc = _handlenothings(accum, desc) if typeout === nothing typeout = eltype(v) end diff --git a/src/operations/select.jl b/src/operations/select.jl index 20802d4c..161a8916 100644 --- a/src/operations/select.jl +++ b/src/operations/select.jl @@ -6,9 +6,9 @@ function select!( thunk::Union{GBScalar, Nothing, Number} = nothing; mask = nothing, accum = nothing, - desc::Descriptor = nothing + desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) thunk === nothing && (thunk = C_NULL) A, desc, _ = _handletranspose(A, desc) accum = getoperator(accum, eltype(C)) @@ -55,7 +55,7 @@ function select( accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(op, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) C = similar(A) select!(op, C, A, thunk; accum, mask, desc) return C diff --git a/src/operations/transpose.jl b/src/operations/transpose.jl index f376c995..7f3afaf1 100644 --- a/src/operations/transpose.jl +++ b/src/operations/transpose.jl @@ -18,7 +18,7 @@ function gbtranspose!( C::GBMatrix, A::GBMatOrTranspose; mask = nothing, accum = nothing, desc = nothing ) - _, mask, accum, desc = _handlectx(nothing, mask, accum, desc) + mask, accum, desc = _handlenothings(mask, accum, desc) if A isa Transpose && desc.input1 == Descriptors.TRANSPOSE throw(ArgumentError("Cannot have A isa Transpose and desc.input1 = Descriptors.TRANSPOSE.")) elseif A isa Transpose @@ -107,9 +107,12 @@ end function _handletranspose( A::GBArray, - desc::Union{Descriptor, Nothing} = nothing, + desc::Union{Descriptor, Nothing, Ptr{Nothing}} = nothing, B::Union{GBArray, Nothing} = nothing ) + if desc == C_NULL + desc = Descriptors.NULL + end if A isa Transpose desc = desc + Descriptors.T0 A = A.parent diff --git a/src/with.jl b/src/with.jl deleted file mode 100644 index e9e669e9..00000000 --- a/src/with.jl +++ /dev/null @@ -1,16 +0,0 @@ -function with(f; op = nothing, mask = nothing, accum = nothing, desc = nothing) - ctxargs = [] - if op !== nothing - push!(ctxargs, ctxop => op) - end - if mask !== nothing - push!(ctxargs, ctxmask => mask) - end - if accum !== nothing - push!(ctxargs, ctxaccum => accum) - end - if desc !== nothing - push!(ctxargs, ctxdesc => desc) - end - with_context(f, ctxargs...) -end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..3f8449e9 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,11 @@ +[deps] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ChainRulesTestUtils = "0.7" +FiniteDifferences = "0.12" diff --git a/test/chainrules/chainrulesutils.jl b/test/chainrules/chainrulesutils.jl index e6f9a8bd..9f9c3780 100644 --- a/test/chainrules/chainrulesutils.jl +++ b/test/chainrules/chainrulesutils.jl @@ -1,4 +1,3 @@ -using FiniteDifferences function test_to_vec(x::T; check_inferred=true) where {T} check_inferred && @inferred FiniteDifferences.to_vec(x) x_vec, back = FiniteDifferences.to_vec(x) diff --git a/test/chainrules/ewiserules.jl b/test/chainrules/ewiserules.jl index f9d60905..0c963532 100644 --- a/test/chainrules/ewiserules.jl +++ b/test/chainrules/ewiserules.jl @@ -4,14 +4,14 @@ #dense first Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) X = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_frule(eadd, X, Y) + test_frule(eadd, X, Y, BinaryOps.PLUS) + test_rrule(eadd, X, Y) + test_rrule(eadd, X, Y, BinaryOps.PLUS) + test_frule(emul, X, Y) + test_frule(emul, X, Y, BinaryOps.TIMES) + test_rrule(emul, X, Y) + test_rrule(emul, X, Y, BinaryOps.TIMES) end end @@ -19,14 +19,14 @@ @testset "Arithmetic Semiring" begin Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) X = GBMatrix(sprand(10, 0.5)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_frule(eadd, X, Y) + test_frule(eadd, X, Y, BinaryOps.PLUS) + test_rrule(eadd, X, Y) + test_rrule(eadd, X, Y, BinaryOps.PLUS) + test_frule(emul, X, Y) + test_frule(emul, X, Y, BinaryOps.TIMES) + test_rrule(emul, X, Y) + test_rrule(emul, X, Y, BinaryOps.TIMES) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 6b6d2cb4..2ba529de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using SparseArrays using Test using Random using ChainRulesTestUtils +using FiniteDifferences Random.seed!(1) function include_test(path) @@ -17,5 +18,4 @@ println("Testing SuiteSparseGraphBLAS.jl") include_test("operations.jl") include_test("chainrules/chainrulesutils.jl") include_test("chainrules/mulrules.jl") - include_test("chainrules/mulrules.jl") end From 248000971f1428caf5a176a1915a080d2f2cabc4 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sun, 11 Jul 2021 14:51:36 -0400 Subject: [PATCH 3/3] Accidentally removed FiniteDifferences --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index b4c0b81e..b52954a0 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"