Skip to content

Commit 37597b2

Browse files
author
Will Kimmerer
authored
Assorted mul rules (#31)
* Some mul rules * Some frules
1 parent 09d22f5 commit 37597b2

File tree

5 files changed

+159
-18
lines changed

5 files changed

+159
-18
lines changed

src/chainrules/mulrules.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,73 @@ function rrule(
4949
pullback(ΔΩ) = mulpullback(ΔΩ)[1:3]
5050
return Ω, pullback
5151
end
52+
53+
54+
# PLUS_DIV:
55+
function rrule(
56+
::typeof(mul),
57+
A::GBMatOrTranspose,
58+
B::GBMatOrTranspose,
59+
::typeof(Semirings.PLUS_DIV)
60+
)
61+
function mulpullback(ΔΩ)
62+
∂A = mul(ΔΩ, one(eltype(A)) ./ B', Semirings.PLUS_TIMES; mask=A)
63+
∂B = (zero(eltype(A)) .- mul(A', ΔΩ; mask=B)) ./ (B .^ 2.)
64+
return NoTangent(), ∂A, ∂B, NoTangent()
65+
end
66+
return mul(A, B, Semirings.PLUS_DIV), mulpullback
67+
end
68+
69+
# PLUS_PLUS:
70+
function frule(
71+
(_, ΔA, ΔB, _),
72+
::typeof(mul),
73+
A::GBMatOrTranspose,
74+
B::GBMatOrTranspose,
75+
::typeof(Semirings.PLUS_PLUS)
76+
)
77+
Ω = mul(A, B, Semirings.PLUS_PLUS)
78+
∂Ω = mul(ΔA, ΔB, Semirings.PLUS_PLUS)
79+
return Ω, ∂Ω
80+
end
81+
82+
function rrule(
83+
::typeof(mul),
84+
A::GBMatOrTranspose,
85+
B::GBMatOrTranspose,
86+
::typeof(Semirings.PLUS_PLUS)
87+
)
88+
function mulpullback(ΔΩ)
89+
∂A = mul(ΔΩ, B', Semirings.PLUS_FIRST; mask=A)
90+
∂B = mul(A', ΔΩ, Semirings.PLUS_SECOND; mask=B)
91+
return NoTangent(), ∂A, ∂B, NoTangent()
92+
end
93+
return mul(A, B, Semirings.PLUS_PLUS), mulpullback
94+
end
95+
96+
# PLUS_MINUS:
97+
function frule(
98+
(_, ΔA, ΔB, _),
99+
::typeof(mul),
100+
A::GBMatOrTranspose,
101+
B::GBMatOrTranspose,
102+
::typeof(Semirings.PLUS_MINUS)
103+
)
104+
Ω = mul(A, B, Semirings.PLUS_MINUS)
105+
∂Ω = mul(ΔA, ΔB, Semirings.PLUS_MINUS)
106+
return Ω, ∂Ω
107+
end
108+
109+
function rrule(
110+
::typeof(mul),
111+
A::GBMatOrTranspose,
112+
B::GBMatOrTranspose,
113+
::typeof(Semirings.PLUS_MINUS)
114+
)
115+
function mulpullback(ΔΩ)
116+
∂A = mul(ΔΩ, B', Semirings.PLUS_FIRST; mask=A)
117+
∂B = mul(A', zero(eltype(ΔΩ)) .- ΔΩ, Semirings.PLUS_SECOND; mask=B)
118+
return NoTangent(), ∂A, ∂B, NoTangent()
119+
end
120+
return mul(A, B, Semirings.PLUS_MINUS), mulpullback
121+
end

src/descriptors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,4 @@ function _loaddescriptors()
183183
end
184184

185185
Base.show(io::IO, ::MIME"text/plain", d::Descriptor) = gxbprint(io, d)
186+
Base.print(io::IO, d::Descriptor) = gxbprint(io, d)

src/operations/map.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ function Base.map!(
77
mask, accum, desc = _handlenothings(mask, accum, desc)
88
op = getoperator(op, eltype(A))
99
accum = getoperator(accum, eltype(C))
10+
A, desc = _handletranspose(A, desc)
1011
if C isa GBVector && A isa GBVector
1112
libgb.GrB_Vector_apply(C, mask, accum, op, A, desc)
1213
elseif C isa GBMatrix && A isa GBMatrix
@@ -36,6 +37,7 @@ function Base.map!(
3637
mask, accum, desc = _handlenothings(mask, accum, desc)
3738
op = getoperator(op, optype(eltype(A), typeof(x)))
3839
accum = getoperator(accum, eltype(C))
40+
_, desc, A = _handletranspose(nothing, desc, A)
3941
if C isa GBVector && A isa GBVector
4042
libgb.scalarvecapply1st[optype(typeof(x), eltype(A))](C, mask, accum, op, x, A, desc)
4143
elseif C isa GBMatrix && A isa GBMatrix
@@ -66,6 +68,7 @@ function Base.map!(
6668
mask, accum, desc = _handlenothings(mask, accum, desc)
6769
op = getoperator(op, optype(eltype(A), typeof(x)))
6870
accum = getoperator(accum, eltype(C))
71+
A, desc, _ = _handletranspose(A, desc)
6972
if C isa GBVector && A isa GBVector
7073
libgb.scalarvecapply2nd[optype(typeof(x), eltype(A))](C, mask, accum, op, A, x, desc)
7174
elseif C isa GBMatrix && A isa GBMatrix
@@ -89,25 +92,41 @@ function Base.map(
8992
return map!(op, similar(A, t), A, x; mask, accum, desc)
9093
end
9194

92-
function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union
93-
)
95+
function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union)
9496
map(BinaryOps.PLUS, u, x)
9597
end
96-
function Base.broadcasted(
97-
::typeof(+), x::valid_union, u::GBArray
98-
)
98+
function Base.broadcasted(::typeof(+), x::valid_union, u::GBArray)
9999
map(BinaryOps.PLUS, x, u)
100100
end
101101

102-
function Base.broadcasted(::typeof(*), u::GBArray, x::valid_union
103-
)
102+
function Base.broadcasted(::typeof(-), u::GBArray, x::valid_union)
103+
map(BinaryOps.MINUS, u, x)
104+
end
105+
function Base.broadcasted(::typeof(-), x::valid_union, u::GBArray)
106+
map(BinaryOps.MINUS, x, u)
107+
end
108+
109+
function Base.broadcasted(::typeof(*), u::GBArray, x::valid_union)
104110
map(BinaryOps.TIMES, u, x)
105111
end
106-
function Base.broadcasted(::typeof(*), x::valid_union, u::GBArray
107-
)
112+
function Base.broadcasted(::typeof(*), x::valid_union, u::GBArray)
108113
map(BinaryOps.TIMES, x, u)
109114
end
110115

116+
function Base.broadcasted(::typeof(/), u::GBArray, x::valid_union)
117+
map(BinaryOps.DIV, u, x)
118+
end
119+
function Base.broadcasted(::typeof(/), x::valid_union, u::GBArray)
120+
map(BinaryOps.DIV, x, u;)
121+
end
122+
123+
function Base.broadcasted(::typeof(^), u::GBArray, x::valid_union)
124+
map(BinaryOps.POW, u, x)
125+
end
126+
function Base.broadcasted(::typeof(^), x::valid_union, u::GBArray)
127+
map(BinaryOps.POW, x, u)
128+
end
129+
111130
"""
112131
map(op::UnaryOp, A::GBArray; kwargs...)::GBArray
113132
map(op::BinaryOp, A::GBArray, x; kwargs...)::GBArray

src/operations/transpose.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function Base.copy(v::LinearAlgebra.Transpose{<:Any, <:GBVector})
106106
end
107107

108108
function _handletranspose(
109-
A::GBArray,
109+
A::Union{GBArray, Nothing} = nothing,
110110
desc::Union{Descriptor, Nothing, Ptr{Nothing}} = nothing,
111111
B::Union{GBArray, Nothing} = nothing
112112
)

test/chainrules/mulrules.jl

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,72 @@
11
@testset "mul" begin
22
@testset "Dense" begin
3-
@testset "Arithmetic Semiring" begin
4-
M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10))
5-
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
3+
M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10))
4+
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
5+
N = GBMatrix(rand(-10.0:0.05:10.0, 10, 11))
6+
@testset "+.*" begin
67
test_frule(mul, M, Y)
78
test_frule(mul, M, Y, Semirings.PLUS_TIMES)
89
test_rrule(mul, M, Y)
910
test_rrule(mul, M, Y, Semirings.PLUS_TIMES)
11+
12+
test_frule(mul, M, N)
13+
test_frule(mul, M, N, Semirings.PLUS_TIMES)
14+
test_rrule(mul, M, N)
15+
test_rrule(mul, M, N, Semirings.PLUS_TIMES)
16+
end
17+
18+
@testset "+.÷" begin
19+
test_rrule(mul, M, Y, Semirings.PLUS_DIV)
20+
test_rrule(mul, M, N, Semirings.PLUS_DIV)
21+
end
22+
23+
@testset "+.+" begin
24+
test_frule(mul, M, Y, Semirings.PLUS_PLUS)
25+
test_frule(mul, M, N, Semirings.PLUS_PLUS)
26+
test_rrule(mul, M, Y, Semirings.PLUS_PLUS)
27+
test_rrule(mul, M, N, Semirings.PLUS_PLUS)
28+
end
29+
30+
@testset "+.-" begin
31+
test_frule(mul, M, Y, Semirings.PLUS_MINUS)
32+
test_frule(mul, M, N, Semirings.PLUS_MINUS)
33+
test_rrule(mul, M, Y, Semirings.PLUS_MINUS)
34+
test_rrule(mul, M, N, Semirings.PLUS_MINUS)
1035
end
1136
end
1237

1338
@testset "Sparse" begin
1439
M = GBMatrix(sprand(100, 10, 0.25))
15-
Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector)
16-
test_frule(mul, M, Y)
17-
test_frule(mul, M, Y, Semirings.PLUS_TIMES)
18-
test_rrule(mul, M, Y)
19-
test_rrule(mul, M, Y, Semirings.PLUS_TIMES)
40+
Y = GBMatrix(sprand(10, 0.1))
41+
N = GBMatrix(sprand(10, 75, 0.05))
42+
@testset "+.*" begin
43+
test_frule(mul, M, Y)
44+
test_frule(mul, M, Y, Semirings.PLUS_TIMES)
45+
test_rrule(mul, M, Y)
46+
test_rrule(mul, M, Y, Semirings.PLUS_TIMES)
47+
48+
test_frule(mul, M, N)
49+
test_frule(mul, M, N, Semirings.PLUS_TIMES)
50+
test_rrule(mul, M, N)
51+
test_rrule(mul, M, N, Semirings.PLUS_TIMES)
52+
end
53+
54+
@testset "+.÷" begin
55+
test_rrule(mul, M, Y, Semirings.PLUS_DIV)
56+
test_rrule(mul, M, N, Semirings.PLUS_DIV)
57+
end
58+
59+
@testset "+.+" begin
60+
test_frule(mul, M, Y, Semirings.PLUS_PLUS)
61+
test_frule(mul, M, N, Semirings.PLUS_PLUS)
62+
test_rrule(mul, M, Y, Semirings.PLUS_PLUS)
63+
test_rrule(mul, M, N, Semirings.PLUS_PLUS)
64+
end
65+
@testset "+.-" begin
66+
test_frule(mul, M, Y, Semirings.PLUS_MINUS)
67+
test_frule(mul, M, N, Semirings.PLUS_MINUS)
68+
test_rrule(mul, M, Y, Semirings.PLUS_MINUS)
69+
test_rrule(mul, M, N, Semirings.PLUS_MINUS)
70+
end
2071
end
2172
end

0 commit comments

Comments
 (0)