diff --git a/Project.toml b/Project.toml index b50964182..6e5693620 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.23" +version = "0.7.24" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 0af358f7b..d34d9d753 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -8,7 +8,8 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) function svd_pullback(Ȳ::Composite) - ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) + # svd_rev does a lot of linear algebra, so it is efficient to unthunk the cotangents beforehand + ∂X = svd_rev(F, unthunk(Ȳ.U), unthunk(Ȳ.S), unthunk(Ȳ.V)) return (NO_FIELDS, ∂X) end return F, svd_pullback diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index e589b3db4..cdd3cab45 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -13,8 +13,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @test dself1 === NO_FIELDS @test dp === DoesNotExist() - ΔF = unthunk(dF) - dself2, dX = dX_pullback(ΔF) + dself2, dX = dX_pullback(dF) @test dself2 === NO_FIELDS X̄_ad = unthunk(dX) X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)) @@ -27,6 +26,26 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo end end + @testset "Thunked inputs" begin + X = randn(4, 3) + F, dX_pullback = rrule(svd, X) + for p in [:U, :S, :V] + Y, dF_pullback = rrule(getproperty, F, p) + Ȳ = randn(size(Y)...) 
+ + _, dF_unthunked, _ = dF_pullback(Ȳ) + + @assert !(getproperty(dF_unthunked, p) isa AbstractThunk) + dF_thunked = map(f->Thunk(()->f), dF_unthunked) + @assert getproperty(dF_thunked, p) isa AbstractThunk + + dself_thunked, dX_thunked = dX_pullback(dF_thunked) + dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked) + @test dself_thunked == dself_unthunked + @test dX_thunked == dX_unthunked + end + end + @testset "+" begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] F, dX_pullback = rrule(svd, X)