From bebd65265876ae2871f120ab2250b6853dc3daf8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 12 Jan 2021 16:33:36 +0000 Subject: [PATCH 1/8] Ensure sincos type stable by using fixed version of ChainRulesCore --- .gitignore | 1 + Project.toml | 10 +++++----- test/rulesets/Base/fastmath_able.jl | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 597fc936a..c8a0f62d3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ docs/build docs/site .idea/* dev/* +Project.toml diff --git a/Project.toml b/Project.toml index ffb057613..1fd367c39 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.45" +version = "0.7.46" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,11 +12,11 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.9.21" -ChainRulesTestUtils = "0.5.5" +ChainRulesCore = "0.9.25" +ChainRulesTestUtils = "0.5, 0.6" Compat = "3" -FiniteDifferences = "0.11.4" -Reexport = "0.2, 1.0" +FiniteDifferences = "0.11, 0.12" +Reexport = "0.2, 1" Requires = "0.5.2, 1" julia = "1" diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index e83737a35..c0365bf78 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -60,7 +60,7 @@ const FASTABLE_AST = quote @testset "Multivariate" begin @testset "sincos(x::$T)" for T in (Float64, ComplexF64) x, Δx, x̄ = randn(T, 3) - Δz = (randn(T), randn(T)) + Δz = Composite{Tuple{T,T}}(randn(T), randn(T)) frule_test(sincos, (x, Δx)) rrule_test(sincos, Δz, (x, x̄)) From e5a49783630ebb46b07d8ed4d09990ceec50deba Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 12 Jan 2021 18:43:14 +0000 Subject: [PATCH 2/8] don't ignore Project.toml --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index c8a0f62d3..597fc936a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,3 @@ docs/build docs/site .idea/* dev/* -Project.toml From 90e9e7eed7aa2f863651c37011509a333b686a1b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 12 Jan 2021 19:39:30 +0000 Subject: [PATCH 3/8] Don't check inferred on BLAS rules --- src/rulesets/LinearAlgebra/blas.jl | 6 +++--- test/rulesets/LinearAlgebra/blas.jl | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 0f740c768..18ac417ea 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -155,11 +155,11 @@ end ##### `BLAS.gemm` ##### -function rrule( +@inline function rrule( ::typeof(gemm), tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T} ) where T<:BlasFloat C = gemm(tA, tB, α, A, B) - function gemv_pullback(C̄) + function gemm_pullback(C̄) β = one(T) if uppercase(tA) === 'N' if uppercase(tB) === 'N' @@ -251,7 +251,7 @@ function rrule( end return (NO_FIELDS, DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end - return C, gemv_pullback + return C, gemm_pullback end function rrule( diff --git a/test/rulesets/LinearAlgebra/blas.jl b/test/rulesets/LinearAlgebra/blas.jl index b61ec45c5..18bb912e7 100644 --- a/test/rulesets/LinearAlgebra/blas.jl +++ b/test/rulesets/LinearAlgebra/blas.jl @@ -94,7 +94,8 @@ (tB, nothing), (α, randn(T)), (A, randn(T, size(A))), - (B, randn(T, size(B))), + (B, randn(T, size(B))); + check_inferred=false, ) rrule_test( @@ -103,7 +104,8 @@ (tA, nothing), (tB, nothing), (A, randn(T, size(A))), - (B, randn(T, size(B))), + (B, randn(T, size(B))); + check_inferred=false, ) end end @@ -121,7 +123,8 @@ (t, nothing), (α, randn(T)), (A, randn(T, size(A))), - (x, randn(T, size(x))), + (x, randn(T, size(x))); + check_inferred=false, ) end end From ff4e057680aa20c47bf54692cee801b2c680f4b5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 12 Jan 2021 19:40:22 +0000 Subject: [PATCH 4/8] remove extra inline added --- src/rulesets/LinearAlgebra/blas.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 18ac417ea..7fcfba18a 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -155,7 +155,7 @@ end ##### `BLAS.gemm` ##### -@inline function rrule( +function rrule( ::typeof(gemm), tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T} ) where T<:BlasFloat C = gemm(tA, tB, α, A, B) From 74946b92430f4e4eee77807385dc2592ddc9609c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 13 Jan 2021 11:31:04 +0000 Subject: [PATCH 5/8] don't check inference for norm on old julia versions --- test/rulesets/LinearAlgebra/norm.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 1b0f9cac5..350fd0d43 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -58,7 +58,9 @@ @testset "rrule" begin rrule_test(norm, ȳ, (x, x̄)) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - rrule_test(norm, ȳ, (MT(x), MT(x̄))) + # we don't check inference on older julia versions. Improvements to + # inference mean on 1.5+ it works, and that is good enough + rrule_test(norm, ȳ, (MT(x), MT(x̄)); check_inferred=VERSION>=v"1.5") end @test extern(rrule(norm, zero(x))[2](ȳ)[2]) ≈ zero(x) @test rrule(norm, x)[2](Zero())[2] isa Zero From 7f29dfefde6a7706303342593eec001fb97fed2f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 13 Jan 2021 16:45:48 +0000 Subject: [PATCH 6/8] Update Project.toml Co-authored-by: Seth Axen --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1fd367c39..6e1e51e54 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.9.25" -ChainRulesTestUtils = "0.5, 0.6" +ChainRulesTestUtils = "0.5, 0.6.1" Compat = "3" FiniteDifferences = "0.11, 0.12" Reexport = "0.2, 1" From de19b72be67796769b4dec5c8f576422946b38bf Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 13 Jan 2021 19:19:40 +0000 Subject: [PATCH 7/8] stop testing inference more pre1.5 on norms --- test/rulesets/LinearAlgebra/norm.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 350fd0d43..9e354fd33 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -97,7 +97,11 @@ rrule_test(fnorm, ȳ, (x, x̄), (p, p̄); kwargs...) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - rrule_test(fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄); kwargs...) + rrule_test( + fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄); + #Don't check inference on old julia, what matters is that works on new + check_inferred=VERSION>=v"1.5", kwargs... + ) end @test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) ≈ zero(x) @test rrule(fnorm, x, p)[2](Zero())[2] isa Zero From fef6f64766fb27f005040eafe9e5f6b13cb4a193 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 13 Jan 2021 19:19:49 +0000 Subject: [PATCH 8/8] stop testing inference more pre1.5 on Hermian constructor --- test/rulesets/LinearAlgebra/symmetric.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 6d5cbdfce..cb2976ad6 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -16,14 +16,24 @@ @test ∂Ω_ad ≈ ∂Ω_fd end @testset "rrule" begin + # on old versions of julia this combination doesn't infer but we don't care as + # it infers fine on modern versions. + check_inferred = !(VERSION <= v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian) + x = randn(T, N, N) ∂x = randn(T, N, N) ΔΩ = randn(T, N, N) @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular) - rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing)) + rrule_test( + SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing); + check_inferred = check_inferred + ) end @testset "back(::Diagonal)" begin - rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing)) + rrule_test( + SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing); + check_inferred = check_inferred + ) end end end