diff --git a/Project.toml b/Project.toml index 0cc618af0..56d086f99 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.21" +version = "0.6.22" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" ChainRules = "1.5" -ChainRulesCore = "1.1" +ChainRulesCore = "1.6" ChainRulesTestUtils = "1" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" diff --git a/README.md b/README.md index 8551bca87..6b2a6517d 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ julia> using Zygote julia> f(x) = 5x + 3 julia> f(10), f'(10) -(53, 5) +(53, 5.0) julia> @code_llvm f'(10) define i64 @"julia_#625_38792"(i64) { diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 4bf7da28a..e879af3f8 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -123,11 +123,33 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR """ @inline wrap_chainrules_input(x) = x @inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() +@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent() @inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, xs) ChainRules.Tangent{Any, typeof(xp)}(xp) end +""" + _project(x, dx) + +Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape. +Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`. +Safe to apply to arbitrary input. +""" +@inline function _project(x, dx) + wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx))) +end + +# Restore splatted arrays +_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) + +# Piracy: +# wrap_chainrules_input doesn't handle array of Union{Int,Nothing} +(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent() + +# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any} +(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im)) + """ ZBack{F}(back) <: Function diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index e4db33471..9dc934a49 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d p = size(x, d) sum(x.^p .+ y) end -([14.0, 22.0], 2, nothing) +([14.0, 22.0], 2.0, nothing) ``` """ function gradient(f, args...) y, back = pullback(f, args...) - return back(sensitivity(y)) + grad = back(sensitivity(y)) + isnothing(grad) ? nothing : map(_project, args, grad) end -Base.adjoint(f::Function) = x -> gradient(f, x)[1] +# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! +Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for legacy reasons + y, back = pullback(f, x) + back(sensitivity(y))[1] +end """ withgradient(f, args...) @@ -95,7 +100,9 @@ true """ function withgradient(f, args...) y, back = pullback(f, args...) - (val = y, grad = back(sensitivity(y))) + grad = back(sensitivity(y)) + results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad) + (val=y, grad=results) end # Param-style wrappers @@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do Grads(...) 
julia> g[x]
-2×3 Matrix{Int64}:
- 7 70 700
- 8 80 800
+2×3 Matrix{Float64}:
+ 7.0 70.0 700.0
+ 8.0 80.0 800.0
 
julia> haskey(g, z) # only x and y are parameters
false
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
 @forward Params.order Base.iterate, Base.length, Base.getindex
 @forward Params.params Base.in
 
+Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
+
 function Base.union!(ps::Params, itrs...)
   foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
   return ps
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 15b994564..9bec64b95 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -38,7 +38,7 @@ end
     dxv = view(dx, inds...)
     dxv .= accum.(dxv, _droplike(dy, dxv))
   end
-  return (dx, map(_->nothing, inds)...)
+  return (_project(x, dx), map(_->nothing, inds)...)
 end
 
 """
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 446e919b1..4e7a3a1cc 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
   Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
 end
 
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
-
-unbroadcast(x::AbstractArray, x̄) =
-  size(x) == size(x̄) ? x̄ :
-  length(x) == length(x̄) ? trim(x, x̄) :
-  trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
-
+function unbroadcast(x::AbstractArray, x̄)
+  N = ndims(x̄)
+  if length(x) == length(x̄)
+    _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
+  else
+    dims = ntuple(d -> size(x, d) == 1 ? d : N+1, N)
+    _project(x, accum_sum(x̄; dims = dims))
+  end
+end
 unbroadcast(x::Number, x̄) = accum_sum(x̄)
 unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
 unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
 
 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
diff --git a/test/complex.jl b/test/complex.jl
index 6a0445b85..1abd1303f 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -1,9 +1,13 @@
 using Zygote, Test, LinearAlgebra
 
+@testset "basic" begin
+
+@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
 @test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
-@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im
-@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
+@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
+@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0
 
 @test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
 @test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@@ -21,6 +25,8 @@ using Zygote, Test, LinearAlgebra
 @test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
 @test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
 
+end # @testset
+
 fs_C_to_R = (real,
              imag,
              abs,
@@ -81,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
     end
   end
 end
+
+@testset "issue 342" begin
+  @test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
+  @test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
+end
+
+@testset "issue 402" begin
+  A = [1,2,3.0]
+  y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
+  bA = B_getindex(1)[1]
+  @test bA isa Diagonal
+  @test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
+end
+
+@testset "issue #917" begin
+  function fun(v)
+    c = v[1:3] + v[4:6]*im
+    r = v[7:9]
+    sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
+  end
+  @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
+end
+
diff --git a/test/cuda.jl b/test/cuda.jl
index 3999ace59..5cb1c8cdc 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -1,12 +1,20 @@
 using CUDA
 using Zygote: Grads
+using LinearAlgebra
 using Random: randn!
 CUDA.allowscalar(false)
 
 # Test GPU movement inside the call to `gradient`
 @testset "GPU movement" begin
   r = rand(Float32, 3,3)
-  @test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
+  @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32}
+  @test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix
+  @test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
+  @test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul!
+ + # Other direction: + @test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray + @test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray end @testset "broadcasting" begin @@ -31,10 +39,19 @@ end g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018 @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] + + # Projection: eltype preservation: + @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32} + @test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback + # structure restoration: + @test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix + @test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal + # non-differentiables + @test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing end @testset "sum(f, x)" begin - a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01]) + a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01] a_gpu = a |> cu f(x) = sum(abs, x) @@ -42,6 +59,18 @@ end g_gpu = gradient(f, a_gpu)[1] @test g_gpu isa CuArray @test g_gpu |> collect ≈ g + + f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule + g2 = gradient(f2, a)[1] + g2_gpu = gradient(f2, a_gpu)[1] + @test g2_gpu isa CuArray + @test g2_gpu |> collect ≈ g2 + + f3(x) = sum(y->y^3, x') # anonymous function + g3 = gradient(f3, a')[1] + g3_gpu = gradient(f3, a_gpu')[1] + @test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure + @test g3_gpu |> collect ≈ g3 end @testset "jacobian" begin @@ -103,5 +132,11 @@ end r = cu(rand(Float32, 3)) grads = (cu(ones(Float32, 3)), 1.f0) @test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads + + @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32} + @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection + + @test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order + @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end diff --git a/test/features.jl b/test/features.jl index 8c460dc98..d683d0d94 100644 --- a/test/features.jl +++ b/test/features.jl @@ -176,9 +176,9 @@ end @test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),) -@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),) +@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) -@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),) +@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) struct Bar{T} a::T @@ -262,6 +262,7 @@ D(f, x) = grad(f, x)[1] @test D(x -> x*D(y -> x+y, 1), 1) == 1 @test D(x -> x*D(y -> x*y, 1), 4) == 8 +@test sin''(1.0) == -sin(1.0) @test sin'''(1.0) == -cos(1.0) f(x) = throw(DimensionMismatch("fubar")) @@ -499,6 +500,25 @@ end @test x[1] == x[2] end +@testset "splats" begin + @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1] + @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0) + + @test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1] + @test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1] + + # https://github.com/FluxML/Zygote.jl/issues/599 + @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector + + # https://github.com/FluxML/Zygote.jl/issues/866 + f866(x) = reshape(x, fill(2, 2)...) 
+  @test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1]
+
+  # https://github.com/FluxML/Zygote.jl/issues/731
+  f731(x) = sum([x' * x, x...])
+  @test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64})
+end
+
 @testset "accumulation" begin
   # from https://github.com/FluxML/Zygote.jl/issues/905
   function net(x1)
diff --git a/test/forward/forward.jl b/test/forward/forward.jl
index 3ae0f6e3a..6aa9173ef 100644
--- a/test/forward/forward.jl
+++ b/test/forward/forward.jl
@@ -36,7 +36,8 @@ end == 1
   x
 end == 0
 
-@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
+@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
+@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real
 
 using LinearAlgebra
 
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index eab959ddd..af49b7697 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -177,7 +177,7 @@ end
 
   # Ensure that nothings work with non-numeric types.
   _, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
-  @test back([nothing]) == ([nothing for _ in 1:3], nothing)
+  @test back([nothing]) == (nothing, nothing)
 end
 
 @testset "view" begin
@@ -332,10 +332,10 @@ end
   @test gradient(x -> sum(log, filter(iseven, x)), 1:10) ==
    (map(x -> iseven(x) ? 1/x : 0, 1:10),)
   @test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) ==
-    (map(x -> iseven(x) ? 2x+2im : 0, 1:10),)
+    (map(x -> iseven(x) ? 2x : 0, 1:10),)
+    # (map(x -> iseven(x) ? 2x+2im : 0, 1:10),)
 end
 
-
 @testset "mean" begin
   @test gradtest(mean, rand(2, 3))
@@ -1157,10 +1157,10 @@ end
 end
 
 @testset "hvcat" begin
-  @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == (1,0,0,0)
-  @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == (0,0,1,0)
-  @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == (0,1,0,0)
-  @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == (0,0,0,1)
+  @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0]
+  @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0]
+  @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0]
+  @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1]
   # https://github.com/FluxML/Zygote.jl/issues/513
   @test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,)
 end
@@ -1375,10 +1375,10 @@ using Zygote: Buffer
     @test gs[1] ≈ map(x -> one.(x), p)
     @test gs[2] ≈ one.(r)
 
-    p = [rand(3,3), rand(3,3)] # redefine `p` after mutation
-    gs = gradient(x -> sum(pop!(x)), p)
-    @test length(gs[1]) == 2
-    @test gs[1][1] == one.(p[1])
+    # p = [rand(3,3), rand(3,3)] # redefine `p` after mutation
+    # gs = gradient(x -> sum(pop!(x)), p)
+    # @test length(gs[1]) == 2
+    # @test gs[1][1] == one.(p[1])
   end
 end
 
@@ -1403,6 +1403,17 @@ end
 end
 
 @testset "AbstractFFTs" begin
+
+  # Many of these tests check a complex gradient to a function with real input. This is now
+  # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function:
+  function oldgradient(f, args...)
+    y, back = Zygote.pullback(f, args...)
+    back(Zygote.sensitivity(y))
+  end
+  # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests
+  # can be updated to use real / complex consistently.
+  # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58
+
   findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1, l=1:n2]
   mirrorIndex(i,N) = i - 2*max(0,i - (N>>1+1))
@@ -1415,11 +1426,11 @@ end
       indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1), l=1:size(X,2)]
       # gradient of ifft(fft) must be (approximately) 1 (for various cases)
-      @test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
+      @test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
       # same for the inverse
-      @test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
+      @test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
       # same for rfft(irfft)
-      @test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
+      @test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
       # rfft isn't actually surjective, so rffft(irfft) can't really be tested this way.
 
       # the gradients are actually just evaluating the inverse transform on the
@@ -1438,22 +1449,22 @@ end
         ((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat), zeros(size(X̂r)), plan_rfft(X), i, X̂r)]
       for (trans, solRe, solIm, P, mI, evalX) in listOfSols
-        @test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
+        @test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
          solRe
-        @test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
+        @test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
          solIm
        if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft])
-          @test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
+          @test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
           solRe
-          @test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
+          @test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
           solIm
        elseif typeof(P) <: AbstractFFTs.Plan
-          @test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
+          @test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
           solRe
          # for whatever reason the rfft_plan doesn't handle this case well,
          # even though irfft does
          if eltype(evalX) <: Real
-            @test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
+            @test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
            solIm
          end
        end
@@ -1464,47 +1475,47 @@ end
   x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982]
   # check ffts for individual dimensions
   for trans in (fft, ifft, bfft)
-    @test gradient((x)->sum(abs.(trans(x))), x)[1] ≈
-      gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
+    @test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈
+      oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
     # switch sum abs order
-    @test gradient((x)->abs(sum((trans(x)))),x)[1] ≈
-      gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
+    @test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈
+      oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
     # dims parameter for the function
-    @test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
-      gradient( (x) -> sum(abs.(trans(x))), x)[1]
+    @test oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
+      oldgradient( (x) -> sum(abs.(trans(x))), x)[1]
     # (1,2) should be the same as no index
-    @test gradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
-      gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
+    @test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
+      oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
     @test gradcheck(x->sum(abs.(trans(x))), x)
     @test gradcheck(x->sum(abs.(trans(x, 2))), x)
   end
 
-  @test gradient((x)->sum(abs.(rfft(x))), x)[1] ≈
-    gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
-  @test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
-    gradient( (x) -> sum(abs.(rfft(x))), x)[1]
+  @test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈
+    oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
+  @test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
+    oldgradient( (x) -> sum(abs.(rfft(x))), x)[1]
 
   # Test type stability of fft
   x = randn(Float64,16)
   P = plan_fft(x)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
-  @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
+  @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
+  @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
+  @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
 
   x = randn(Float64,16,16)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+  @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
+  @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
 
   x = randn(Float32,16)
   P = plan_fft(x)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
-  @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
+  @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
+  @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
+  @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
 
   x = randn(Float32,16,16)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
+  @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
+  @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
 end
 
 @testset "FillArrays" begin
@@ -1668,7 +1679,7 @@ end
   # check that type is not unnecessarily promoted
   # https://github.com/FluxML/Zygote.jl/issues/663
   @test gradient(norm, randn(Float32, 2, 2)) isa Tuple{Matrix{Float32}}
-  @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float32}
+  @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float64}
   @test gradient(norm, randn(Float32, 2, 2), 3f0) isa Tuple{Matrix{Float32},Float32}
   @test gradient(norm, randn(ComplexF32, 2, 2), 3.5f0) isa Tuple{Matrix{ComplexF32},Float32}
 
diff --git a/test/structures.jl b/test/structures.jl
index 37c0e246a..5a951a621 100644
--- a/test/structures.jl
+++ b/test/structures.jl
@@ -53,6 +53,7 @@ struct A594 x::Float64 end
   Y = randn(2,2)
   ∇ = gradient(g,X,Y)
   @test ∇[1] == [(x = 2.0,); (x = 2.0,)]
+  @test vec(∇[1]) == [(x = 2.0,); (x = 2.0,)]
   @test ∇[2] == [1 1; 1 1]
 end
 
diff --git a/test/utils.jl b/test/utils.jl
index 70a8ebd63..b6d6ed018 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -19,16 +19,22 @@ using Zygote: hessian_dual, hessian_reverse
   @test_throws Exception hess(identity, randn(2))
 end
 
 @testset "diagonal hessian" begin
   @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],)
 
-  xs, y = randn(2,3), rand()
-  f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
-
-  dx, dy = diaghessian(f34, xs, y)
-  @test size(dx) == size(xs)
-  @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
-  @test dy ≈ hessian(y -> f34(xs,y), y)
+  if VERSION > v"1.6-"
+    # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6:
+    # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError:
+    # https://github.com/JuliaDiff/ChainRules.jl/issues/525
+    # Fixed in 1.6 by: https://github.com/JuliaLang/julia/pull/36030
+    xs, y = randn(2,3), rand()
+    f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
+
+    dx, dy = diaghessian(f34, xs, y)
+    @test size(dx) == size(xs)
+    @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
+    @test dy ≈ hessian(y -> f34(xs,y), y)
+  end
 
   zs = randn(7,13) # test chunk mode
   @test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
@@ -67,6 +73,7 @@ end
   j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array
   @test j5[1] isa Matrix
   @test vec(j5[1]) == [1, 0]
+  @test j5[2] == [0, 1]
 
   @test_throws ArgumentError jacobian(identity, [1,2,3+im])
   @test_throws ArgumentError jacobian(sum, [1,2,3+im]) # scalar, complex
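
Worked example of the projection behaviour introduced above, distilled from this PR's own tests (a sketch, not part of the diff; assumes Zygote 0.6.22 with LinearAlgebra loaded):

julia> using Zygote, LinearAlgebra

julia> gradient(x -> x.re, 2+3im)  # natural complex differential, was ((re = 1, im = nothing),)
(1.0 + 0.0im,)

julia> gradient(x -> imag(conj(x) + 0.3im), 0.3)[1] ≈ 0  # real input, so the old -im gradient is projected onto the reals
true

julia> gradient(x -> sum(exp.(x)), Diagonal([1.0, 2.0]))[1] isa Diagonal  # structured matrices are restored
true

julia> gradient(xs -> hvcat((2,2), xs...)[1,1], [1,2,3,4])[1]  # splatted arrays come back as arrays, not tuples
4-element Vector{Float64}:
 1.0
 0.0
 0.0
 0.0

Note that `f'(x)` via `Base.adjoint` deliberately skips `_project`, as the comment in src/compiler/interface.jl above explains, so legacy code keeps its unprojected gradients.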