Skip to content

Commit

Permalink
Merge pull request #342 from JuliaDiff/ox/crtup
Browse files Browse the repository at this point in the history
Update to new version of ChainRulesTestUtils and FiniteDifferences
  • Loading branch information
oxinabox authored Jan 13, 2021
2 parents 424cff3 + fef6f64 commit f540c44
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 15 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.45"
version = "0.7.46"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,11 +12,11 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.21"
ChainRulesTestUtils = "0.5.5"
ChainRulesCore = "0.9.25"
ChainRulesTestUtils = "0.5, 0.6.1"
Compat = "3"
FiniteDifferences = "0.11.4"
Reexport = "0.2, 1.0"
FiniteDifferences = "0.11, 0.12"
Reexport = "0.2, 1"
Requires = "0.5.2, 1"
julia = "1"

Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function rrule(
::typeof(gemm), tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T}
) where T<:BlasFloat
C = gemm(tA, tB, α, A, B)
function gemv_pullback(C̄)
function gemm_pullback(C̄)
β = one(T)
if uppercase(tA) === 'N'
if uppercase(tB) === 'N'
Expand Down Expand Up @@ -251,7 +251,7 @@ function rrule(
end
return (NO_FIELDS, DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B)
end
return C, gemv_pullback
return C, gemm_pullback
end

function rrule(
Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ const FASTABLE_AST = quote
@testset "Multivariate" begin
@testset "sincos(x::$T)" for T in (Float64, ComplexF64)
x, Δx, x̄ = randn(T, 3)
Δz = (randn(T), randn(T))
Δz = Composite{Tuple{T,T}}(randn(T), randn(T))

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
Expand Down
9 changes: 6 additions & 3 deletions test/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
(tB, nothing),
(α, randn(T)),
(A, randn(T, size(A))),
(B, randn(T, size(B))),
(B, randn(T, size(B)));
check_inferred=false,
)

rrule_test(
Expand All @@ -103,7 +104,8 @@
(tA, nothing),
(tB, nothing),
(A, randn(T, size(A))),
(B, randn(T, size(B))),
(B, randn(T, size(B)));
check_inferred=false,
)
end
end
Expand All @@ -121,7 +123,8 @@
(t, nothing),
(α, randn(T)),
(A, randn(T, size(A))),
(x, randn(T, size(x))),
(x, randn(T, size(x)));
check_inferred=false,
)
end
end
Expand Down
10 changes: 8 additions & 2 deletions test/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@
@testset "rrule" begin
rrule_test(norm, ȳ, (x, x̄))
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
rrule_test(norm, ȳ, (MT(x), MT(x̄)))
# we don't check inference on older julia versions. Improvements to
# inference mean on 1.5+ it works, and that is good enough
rrule_test(norm, ȳ, (MT(x), MT(x̄)); check_inferred=VERSION>=v"1.5")
end
@test extern(rrule(norm, zero(x))[2](ȳ)[2]) zero(x)
@test rrule(norm, x)[2](Zero())[2] isa Zero
Expand Down Expand Up @@ -95,7 +97,11 @@

rrule_test(fnorm, ȳ, (x, x̄), (p, p̄); kwargs...)
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
rrule_test(fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄); kwargs...)
rrule_test(
fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄);
#Don't check inference on old julia, what matters is that works on new
check_inferred=VERSION>=v"1.5", kwargs...
)
end
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) zero(x)
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
Expand Down
14 changes: 12 additions & 2 deletions test/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,24 @@
@test ∂Ω_ad ∂Ω_fd
end
@testset "rrule" begin
# on old versions of julia this combination doesn't infer but we don't care as
# it infers fine on modern versions.
check_inferred = !(VERSION <= v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)

x = randn(T, N, N)
∂x = randn(T, N, N)
ΔΩ = randn(T, N, N)
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
rrule_test(
SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing);
check_inferred = check_inferred
)
end
@testset "back(::Diagonal)" begin
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
rrule_test(
SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
check_inferred = check_inferred
)
end
end
end
Expand Down

2 comments on commit f540c44

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/27931

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.46 -m "<description of version>" f540c44cf35ab76b1cadac7286d18ee02a13bcd4
git push origin v0.7.46

Please sign in to comment.