From 7481f5511ef8d46517186ba6675da1fa1004fa05 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 14:48:46 +0100 Subject: [PATCH 01/22] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a5b16dc76..52fd73087 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.4.0" +version = "1.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 1802186d823c52a965018113494820176ce57477 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 14:48:59 +0100 Subject: [PATCH 02/22] Add undef Array constructor --- src/rulesets/Base/array.jl | 2 ++ test/rulesets/Base/array.jl | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index ffd97464c..f147e13d9 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -1,3 +1,5 @@ +ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) + ##### ##### `reshape` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index cb9a56940..26530584e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,3 +1,8 @@ +@testset "constructors" begin + test_rrule(Array{Float64, 1}, undef, 5) + test_rrule(Array{Float32, 3}, undef, 5, 4, 3) +end + @testset "reshape" begin test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) From 880c50f8b6a04f711742506b03a1d615af507b38 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 15:00:40 +0100 Subject: [PATCH 03/22] construct Array from existing AbstractArray --- src/rulesets/Base/array.jl | 10 ++++++++++ test/rulesets/Base/array.jl | 10 ++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index f147e13d9..c016303bd 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -1,5 +1,15 @@ +##### +##### constructors +##### + ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) +function rrule(::Type{T}, x::AbstractArray) where {T<:Array} + project_x = ProjectTo(x) + Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return T(x), Array_pullback +end + ##### ##### `reshape` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 26530584e..25a3398da 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,6 +1,12 @@ @testset "constructors" begin - test_rrule(Array{Float64, 1}, undef, 5) - test_rrule(Array{Float32, 3}, undef, 5, 4, 3) + @testset "undef" begin + test_rrule(Array{Float64, 1}, undef, 5) + test_rrule(Array{Float32, 3}, undef, 5, 4, 3) + end + @testset "from existing array" begin + test_rrule(Array, randn(2, 5)) + test_rrule(Array, Diagonal(randn(5))) + end end @testset "reshape" begin From 7a0375c01129ae4696dc1823081478ba5e726677 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 16:36:28 +0100 Subject: [PATCH 04/22] vect implementation --- src/rulesets/Base/array.jl | 25 +++++++++++++++++++++++++ test/rulesets/Base/array.jl | 6 ++++++ 2 files changed, 31 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index c016303bd..f6002c0a2 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -10,6 +10,31 @@ function rrule(::Type{T}, x::AbstractArray) where {T<:Array} return T(x), Array_pullback end +##### +##### `vect` +##### + +@non_differentiable Base.vect() + +# Don't worry about projection here. The data passes straight through, so if a cotangent has +# the wrong type for some reason, it must be the fault of another rule somewhere. +function rrule(::typeof(Base.vect), X...) + function vect_pullback(ȳ) + X̄ = ntuple(n -> ȳ[n], length(X)) + return (NoTangent(), X̄...) + end + return Base.vect(X...), vect_pullback +end + +# # Edge case: Numbers get promoted to other numbers, so we need to project. +# function rrule(::typeof(Base.vect), X::Number...) +# project +# function vect_pullback(ȳ) +# X̄ = ntuple(n -> ) +# end +# return Base.vect(X...), vect_pullback +# end + ##### ##### `reshape` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 25a3398da..f1a299d50 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -9,6 +9,12 @@ end end +@testset "vect" begin + test_rrule(Base.vect) + test_rrule(Base.vect, 5.0, 4.0, 3.0) + test_rrule(Base.vect, randn(2, 2), randn(3, 3); check_inferred=false) +end + @testset "reshape" begin test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) From 5c061fba02d6b4f08ddf3e097ab1da229a938b61 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 17:44:28 +0100 Subject: [PATCH 05/22] Bump precision --- test/rulesets/Base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f1a299d50..0ccbdc82e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,7 +1,7 @@ @testset "constructors" begin @testset "undef" begin test_rrule(Array{Float64, 1}, undef, 5) - test_rrule(Array{Float32, 3}, undef, 5, 4, 3) + test_rrule(Array{Float64, 3}, undef, 5, 4, 3) end @testset "from existing array" begin test_rrule(Array, randn(2, 5)) From 21b6bed8b848ddb837f90998a3fc2c3b424051f6 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 1 Aug 2021 17:44:49 +0100 Subject: [PATCH 06/22] Additional tests Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- test/rulesets/Base/array.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f1a299d50..f88f7f173 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -6,6 +6,9 @@ @testset "from existing array" begin test_rrule(Array, randn(2, 5)) test_rrule(Array, Diagonal(randn(5))) + test_rrule(Matrix, Diagonal(randn(5))) + test_rrule(Matrix, transpose(randn(4))) + test_rrule(Array{ComplexF64}, randn(3)) end end From 1a0ead227e75c174e54bc60734bee48c200c9af0 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 19:44:13 +0100 Subject: [PATCH 07/22] Fix undef tests --- test/rulesets/Base/array.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 7eb118720..8a54d57ff 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,7 +1,13 @@ @testset "constructors" begin + + # We can't use test_rrule here (as it's currently implemented) because the elements of + # the array have arbitrary values. The only thing we can do is ensure that we're getting + # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. @testset "undef" begin - test_rrule(Array{Float64, 1}, undef, 5) - test_rrule(Array{Float64, 3}, undef, 5, 4, 3) + val, pullback = rrule(Array{Float64}, undef, 5) + @test size(val) == (5, ) + @test val isa Array{Float64, 1} + @test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent()) end @testset "from existing array" begin test_rrule(Array, randn(2, 5)) From cac5b4f73848a3e47a1271574d468e02fc0237c7 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 21:45:48 +0100 Subject: [PATCH 08/22] Constraint vect implementation --- src/rulesets/Base/array.jl | 14 +++----------- test/rulesets/Base/array.jl | 9 +++++++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index f6002c0a2..47063e526 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -18,23 +18,15 @@ end # Don't worry about projection here. The data passes straight through, so if a cotangent has # the wrong type for some reason, it must be the fault of another rule somewhere. -function rrule(::typeof(Base.vect), X...) +function rrule(::typeof(Base.vect), X::T...) where {T} + l = length(X) function vect_pullback(ȳ) - X̄ = ntuple(n -> ȳ[n], length(X)) + X̄ = ntuple(n -> ȳ[n], l) return (NoTangent(), X̄...) end return Base.vect(X...), vect_pullback end -# # Edge case: Numbers get promoted to other numbers, so we need to project. -# function rrule(::typeof(Base.vect), X::Number...) -# project -# function vect_pullback(ȳ) -# X̄ = ntuple(n -> ) -# end -# return Base.vect(X...), vect_pullback -# end - ##### ##### `reshape` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 8a54d57ff..059b0ec59 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -20,8 +20,13 @@ end @testset "vect" begin test_rrule(Base.vect) - test_rrule(Base.vect, 5.0, 4.0, 3.0) - test_rrule(Base.vect, randn(2, 2), randn(3, 3); check_inferred=false) + @testset "homogeneous type" begin + test_rrule(Base.vect, 5.0, 4.0, 3.0) + test_rrule(Base.vect, randn(2, 2), randn(3, 3)) + end + @testset "inhomogeneous type" begin + test_rrule(Base.vect, 5.0, 4) + end end @testset "reshape" begin From 0d18c81c17e8965223246e797c6da91afdb8796b Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 22:27:47 +0100 Subject: [PATCH 09/22] Add float-only test --- test/rulesets/Base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 059b0ec59..974fc874f 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -25,7 +25,7 @@ end test_rrule(Base.vect, randn(2, 2), randn(3, 3)) end @testset "inhomogeneous type" begin - test_rrule(Base.vect, 5.0, 4) + test_rrule(Base.vect, 5.0, 3f0; atol=1e-6, rtol=1e-6) # tolerance due to Float32. end end From 6564952379a1c377f63dcf1b470909ec4e2fb30e Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 22:28:13 +0100 Subject: [PATCH 10/22] Link to non_differentiable tests issue --- test/rulesets/Base/array.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 974fc874f..adeea1a29 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -3,6 +3,7 @@ # We can't use test_rrule here (as it's currently implemented) because the elements of # the array have arbitrary values. The only thing we can do is ensure that we're getting # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. + # Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202 @testset "undef" begin val, pullback = rrule(Array{Float64}, undef, 5) @test size(val) == (5, ) From 060825a92145e166fff36c509eafcf9733f4c9c7 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 1 Aug 2021 22:34:38 +0100 Subject: [PATCH 11/22] Type-stable `vect` implementation Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/rulesets/Base/array.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 47063e526..10106b7bc 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -18,11 +18,9 @@ end # Don't worry about projection here. The data passes straight through, so if a cotangent has # the wrong type for some reason, it must be the fault of another rule somewhere. -function rrule(::typeof(Base.vect), X::T...) where {T} - l = length(X) +function rrule(::typeof(Base.vect), X::Vararg{T,N}) where {T,N} function vect_pullback(ȳ) - X̄ = ntuple(n -> ȳ[n], l) - return (NoTangent(), X̄...) + return (NoTangent(), NTuple{N}(ȳ)...) end return Base.vect(X...), vect_pullback end From fca9087c3fed1877a1b7655451aafd550496de7c Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 22:35:55 +0100 Subject: [PATCH 12/22] type stable vect pullback --- src/rulesets/Base/array.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 47063e526..32805330d 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -18,10 +18,20 @@ end # Don't worry about projection here. The data passes straight through, so if a cotangent has # the wrong type for some reason, it must be the fault of another rule somewhere. -function rrule(::typeof(Base.vect), X::T...) where {T} - l = length(X) +function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} function vect_pullback(ȳ) - X̄ = ntuple(n -> ȳ[n], l) + X̄ = ntuple(n -> ȳ[n], N) + return (NoTangent(), X̄...) + end + return Base.vect(X...), vect_pullback +end + +# Numbers need to be projected because they don't pass straight through the function. +# More generally, we would ideally project everything. +function rrule(::typeof(Base.vect), X::Vararg{Number, N}) where {N} + projects = map(ProjectTo, X) + function vect_pullback(ȳ) + X̄ = ntuple(n -> projects[n](ȳ[n]), N) return (NoTangent(), X̄...) end return Base.vect(X...), vect_pullback From f9d22b6d5680d5344e289d01e17ad731ab07884b Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 1 Aug 2021 22:43:48 +0100 Subject: [PATCH 13/22] Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/rulesets/Base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index d8f58f956..7220f3ed1 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -28,7 +28,7 @@ end # Numbers need to be projected because they don't pass straight through the function. # More generally, we would ideally project everything. -function rrule(::typeof(Base.vect), X::Vararg{Number, N}) where {N} +function rrule(::typeof(Base.vect), X::Vararg{Union{Number,AbstractArray{<:Number}}, N}) where {N} projects = map(ProjectTo, X) function vect_pullback(ȳ) X̄ = ntuple(n -> projects[n](ȳ[n]), N) From b3bf297838c84e66c5b142f01bf458d8349ab80e Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 22:46:46 +0100 Subject: [PATCH 14/22] Test Union{Number, AbstractArray} --- test/rulesets/Base/array.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index adeea1a29..e79d3f666 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -27,6 +27,7 @@ end end @testset "inhomogeneous type" begin test_rrule(Base.vect, 5.0, 3f0; atol=1e-6, rtol=1e-6) # tolerance due to Float32. + test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) end end From 5971c589da8f8e848da924dee5d3b96ca0931b01 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 1 Aug 2021 23:35:38 +0100 Subject: [PATCH 15/22] Don't test inference below 1.6 --- test/rulesets/Base/array.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index e79d3f666..076808c84 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -26,7 +26,10 @@ end test_rrule(Base.vect, randn(2, 2), randn(3, 3)) end @testset "inhomogeneous type" begin - test_rrule(Base.vect, 5.0, 3f0; atol=1e-6, rtol=1e-6) # tolerance due to Float32. + test_rrule( + Base.vect, 5.0, 3f0; + atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6", + ) # tolerance due to Float32. test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) end end From acd82e1f93f1ed4234256801bd89fd62df96b7bc Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 2 Aug 2021 13:13:05 +0100 Subject: [PATCH 16/22] Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/rulesets/Base/array.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 7220f3ed1..e7f8e908d 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -19,10 +19,7 @@ end # Don't worry about projection here. The data passes straight through, so if a cotangent has # the wrong type for some reason, it must be the fault of another rule somewhere. function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} - function vect_pullback(ȳ) - X̄ = ntuple(n -> ȳ[n], N) - return (NoTangent(), NTuple{N}(ȳ)...) - end + vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...) return Base.vect(X...), vect_pullback end From 48e5a6703fa1055f3ac2f28b079d6de069b1dda9 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 2 Aug 2021 13:13:59 +0100 Subject: [PATCH 17/22] Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/rulesets/Base/array.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index e7f8e908d..6604d723b 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -23,8 +23,7 @@ function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} return Base.vect(X...), vect_pullback end -# Numbers need to be projected because they don't pass straight through the function. -# More generally, we would ideally project everything. +# Numbers and arrays are often promoted, to make a uniform vector; ProjectTo here reverses this function rrule(::typeof(Base.vect), X::Vararg{Union{Number,AbstractArray{<:Number}}, N}) where {N} projects = map(ProjectTo, X) function vect_pullback(ȳ) From a061b02ea7ceacce6619ab655c2e4364fe495a14 Mon Sep 17 00:00:00 2001 From: WT Date: Mon, 2 Aug 2021 13:17:49 +0100 Subject: [PATCH 18/22] Style fix --- src/rulesets/Base/array.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 6604d723b..2d92e6324 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -24,7 +24,10 @@ function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} end # Numbers and arrays are often promoted, to make a uniform vector; ProjectTo here reverses this -function rrule(::typeof(Base.vect), X::Vararg{Union{Number,AbstractArray{<:Number}}, N}) where {N} +function rrule( + ::typeof(Base.vect), + X::Vararg{Union{Number,AbstractArray{<:Number}}, N}, +) where {N} projects = map(ProjectTo, X) function vect_pullback(ȳ) X̄ = ntuple(n -> projects[n](ȳ[n]), N) From f5bb4fc5679ad0d2a455f1efa829bcb2afa2ac9a Mon Sep 17 00:00:00 2001 From: WT Date: Mon, 2 Aug 2021 13:18:18 +0100 Subject: [PATCH 19/22] Style fix --- src/rulesets/Base/array.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2d92e6324..2648f8b9f 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -23,7 +23,8 @@ function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} return Base.vect(X...), vect_pullback end -# Numbers and arrays are often promoted, to make a uniform vector; ProjectTo here reverses this +# Numbers and arrays are often promoted, to make a uniform vector. +# ProjectTo here reverses this function rrule( ::typeof(Base.vect), X::Vararg{Union{Number,AbstractArray{<:Number}}, N}, From a3ebb8b45ae2fbc46194a75fb52500f5fe76091a Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 2 Aug 2021 13:24:19 +0100 Subject: [PATCH 20/22] Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/rulesets/Base/array.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2648f8b9f..4b702b505 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -16,8 +16,7 @@ end @non_differentiable Base.vect() -# Don't worry about projection here. The data passes straight through, so if a cotangent has -# the wrong type for some reason, it must be the fault of another rule somewhere. +# Case of uniform type `T`: the data passes straight through, so no projection should be required. function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...) return Base.vect(X...), vect_pullback From ed60ed4283a5f035473955ace47e8720917f21a2 Mon Sep 17 00:00:00 2001 From: WT Date: Mon, 2 Aug 2021 13:24:35 +0100 Subject: [PATCH 21/22] Style fix --- src/rulesets/Base/array.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4b702b505..928c0cf57 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -16,7 +16,8 @@ end @non_differentiable Base.vect() -# Case of uniform type `T`: the data passes straight through, so no projection should be required. +# Case of uniform type `T`: the data passes straight through, +# so no projection should be required. function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...) return Base.vect(X...), vect_pullback From 000b33f97cb195c957340f2e1940b808e5a6e60d Mon Sep 17 00:00:00 2001 From: WT Date: Mon, 2 Aug 2021 15:01:19 +0100 Subject: [PATCH 22/22] Add an extra test --- test/rulesets/Base/array.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 076808c84..8a0bf7d85 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -22,6 +22,7 @@ end @testset "vect" begin test_rrule(Base.vect) @testset "homogeneous type" begin + test_rrule(Base.vect, (5.0, ), (4.0, )) test_rrule(Base.vect, 5.0, 4.0, 3.0) test_rrule(Base.vect, randn(2, 2), randn(3, 3)) end