From ebe424beb7635b3db71d35a79cfb89bcefdbf76d Mon Sep 17 00:00:00 2001 From: Stijn de Waele Date: Mon, 2 Dec 2019 23:26:23 -0500 Subject: [PATCH 1/4] matrix exp(A) returns real-valued adjoint for real-valued A --- src/lib/array.jl | 3 ++- test/gradcheck.jl | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 378110af2..ab8909418 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -485,7 +485,8 @@ end X = _pairdiffquotmat(exp, n, w, ew, ew, ew) V = E.vectors VF = factorize(V) - Ā = (V * ((VF \ F̄' * V) .* X) / VF)' + Āc = (V * ((VF \ F̄' * V) .* X) / VF)' + Ā = eltype(A) <: Real ? real(Āc) : Āc return (Ā,) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index a069ff207..0ccfb0fd0 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -566,6 +566,19 @@ end end end end + + @testset "real-valued" + A = [ 0.0 1.0 0.0 + 0.0 0.0 1.0 + -4.34 -18.31 -0.43] + _,B = Zygote.pullback(exp,A) + Ȳ = rand(MersenneTwister(347392),3,3) + @test isreal(B(Ȳ)[1]) + # Works when `exp` adjoint returns a real-valued array + x = [1.0] + f(A,x) = exp(A*x[1]) + @test gradtest(f,A,x) + end end _hermsymtype(::Type{<:Symmetric}) = Symmetric From 9acd17910d8a7761287f047cb634e72cbd9aa9e8 Mon Sep 17 00:00:00 2001 From: sdewaele <14310676+sdewaele@users.noreply.github.com> Date: Tue, 3 Dec 2019 13:31:05 -0500 Subject: [PATCH 2/4] Update src/lib/array.jl Co-Authored-By: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index ab8909418..a881eb091 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -486,7 +486,7 @@ end V = E.vectors VF = factorize(V) Āc = (V * ((VF \ F̄' * V) .* X) / VF)' - Ā = eltype(A) <: Real ? real(Āc) : Āc + Ā = eltype(A) <: Real && eltype(F̄) <: Real ? real(Āc) : Āc return (Ā,) end From 17a36b0a1e1a9fd2934ac31a8606edef0f814ab2 Mon Sep 17 00:00:00 2001 From: Stijn de Waele Date: Tue, 3 Dec 2019 21:49:39 -0500 Subject: [PATCH 3/4] fixed isreal test for real-valued exp pullback --- test/gradcheck.jl | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 0ccfb0fd0..4457fa83b 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -565,19 +565,12 @@ end end end end - end - - @testset "real-valued" A = [ 0.0 1.0 0.0 0.0 0.0 1.0 - -4.34 -18.31 -0.43] - _,B = Zygote.pullback(exp,A) - Ȳ = rand(MersenneTwister(347392),3,3) - @test isreal(B(Ȳ)[1]) - # Works when `exp` adjoint returns a real-valued array - x = [1.0] - f(A,x) = exp(A*x[1]) - @test gradtest(f,A,x) + -4.34 -18.31 -0.43] + _,back = Zygote.pullback(exp,A) + Ȳ = rand(3,3) + @test isreal(back(Ȳ)[1]) end end From 2c22210a068603dd4e544f67fb92222e214ef118 Mon Sep 17 00:00:00 2001 From: Stijn de Waele Date: Sat, 14 Dec 2019 19:39:32 -0500 Subject: [PATCH 4/4] use isreal(M) instead of eltype(M)<: Real --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index a881eb091..c8cfa44f3 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -486,7 +486,7 @@ end V = E.vectors VF = factorize(V) Āc = (V * ((VF \ F̄' * V) .* X) / VF)' - Ā = eltype(A) <: Real && eltype(F̄) <: Real ? real(Āc) : Āc + Ā = isreal(A) && isreal(F̄) ? real(Āc) : Āc return (Ā,) end