From e94c178390be66684a781ed2b8318b05c925a00e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Nov 2020 12:28:56 +0530 Subject: [PATCH 01/13] correctness tests --- test/utils.jl | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index bbd514c00a..08474a77a4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -138,12 +138,47 @@ end o = ones(s) z = zeros(s) Z = Zeros() + a = ones(3,3) + b = zeros(3,3) + b′ = Zeros(3,3) + + + @testset "Basic operations" begin + a = rand(3,3) + b = zeros(3,3) + bz = Zeros(3,3) + + for op in (+, -) + @test op(a,b) == op(a, bz) + end + + for op in (+, -) + gs = gradient((a,b) -> sum(op(a, b)), a, b) + gsz = gradient((a,b) -> sum(op(a, b)), a, bz) + @test gs[1] == gsz[1] + @test gsz[2] === nothing + end + + # Check with broadcasting + b = zeros(3,3,3) + bz = Zeros(3,3,3) + + for op in (+, -) + @test broadcast(op, a,b) == broadcast(op, a, bz) + end + + for op in (+, -) + gs = gradient((a,b) -> sum(broadcast(op, a, b)), a, b) + gsz = gradient((a,b) -> sum(broadcast(op, a, b)), a, bz) + @test gs[1] == gsz[1] + @test gsz[2] === nothing + end + end @testset "Explicit" begin gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) g = gfun(o, z) @test gfun(o, Z) == (g[1], nothing) - g = gfun(z, o) @test gfun(Z, o) == (nothing, g[2]) end @@ -272,4 +307,4 @@ end testdense(re(p), bt) end end -end \ No newline at end of file +end From 48880933c20d1ed0c83c25f4c8f495dfb9e6db79 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Nov 2020 12:29:59 +0530 Subject: [PATCH 02/13] add zeros array implementation --- src/zeros.jl | 82 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/src/zeros.jl b/src/zeros.jl index fef9d1862e..93bc583301 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -21,29 +21,75 @@ julia> bias_less_conv.bias Flux.Zeros() ``` """ -struct Zeros end -# To allow for things like Dense(10, 2, initb = Zeros) -Zeros(args...) = Zeros() +struct Zeros{T,N} <: AbstractArray{T,N} + dims::NTuple{N,Int} +end -Base.reshape(x::Zeros, dims...) = x +Zeros(::Type{T}, dims...) where T = Zeros{T,length(dims)}(dims) +Zeros(dims...) = Zeros(Bool, dims...) -+(::Zeros, b::AbstractArray) = b -+(a::AbstractArray, ::Zeros) = a -+(a::Zeros, ::Zeros) = a +Base.reshape(x::Zeros{T}, dims::Union{Colon,Int}...) where T = Zeros(T, Base._reshape_uncolon(x, dims)...) +Base.getindex(z::Zeros, args...) = error("Calling getindex on Zeros object, materialize to normal array or check for correctness") +Base.collect(x::Zeros{T}) where T = zeros(T, x.dims...) --(::Zeros, b::AbstractArray) = -b --(a::AbstractArray, ::Zeros) = a --(a::Zeros, ::Zeros) = a +Base.size(xs::Zeros) = xs.dims +Base.copyto!(a::Zeros, b::Zeros) = b + +Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) + +@adjoint reshape(xs::Zeros{T}, dims...) where T = + reshape(xs, dims...), _ -> nothing + +# Define basic ops +for f in (:+, :-) + @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) + @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) + a + end +end + ++(a::Zeros, b::AbstractArray) = b + a +-(a::Zeros, b::AbstractArray) = -b + a + +Base.copy(xs::Zeros{T,N}) where {T,N} = xs + +for op in (:+, :-) + @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) + bs = Broadcast.broadcast_shape(size(a), size(b)) + size(a) == bs && return a + sz = similar(a, bs) + sz .= a + end +end + +broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) +broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) + +for op in (:+, :-) + @eval @adjoint function Base.broadcasted(::typeof($op), a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} + Base.broadcasted($op, a, b), Δ -> begin + dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) + da = dims == tuple(1:N...) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) + (nothing, transpose(da), nothing) + end + end +end + +Base.sum(z::Zeros{T}) where T = zero(T) + +for op in (:+, :-) + @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} + $op(a, b), Δ -> begin + (Δ, nothing) + end + end +end # Some opportunities to avoid scalar indexing, intermediaries # Since it replicates a little of what we expect Base to do, # it should be possible to remove in the future, but for now, # these help with performance. -broadcasted(::typeof(+), a::AbstractArray, b::Zeros) = a -broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b -broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a -broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b -# Need adjoints for these or else the gradient w.r.t to the non-Zeros arg will be nothing as well -@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing) -@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) -@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) +broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b +broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b From 12b5444c9c238e8b0cf4e18104f2684103073407 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Nov 2020 14:16:31 +0530 Subject: [PATCH 03/13] remove extra transpose --- src/zeros.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeros.jl b/src/zeros.jl index 93bc583301..8c754b4431 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -70,7 +70,7 @@ for op in (:+, :-) Base.broadcasted($op, a, b), Δ -> begin dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) da = dims == tuple(1:N...) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) - (nothing, transpose(da), nothing) + (nothing, da, nothing) end end end From e67db6f45f8794bf8b6e5c0a84d5539988c48800 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Nov 2020 19:24:45 +0530 Subject: [PATCH 04/13] add cuda compat --- src/zeros.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/zeros.jl b/src/zeros.jl index 8c754b4431..dfeb83070f 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -21,7 +21,7 @@ julia> bias_less_conv.bias Flux.Zeros() ``` """ -struct Zeros{T,N} <: AbstractArray{T,N} +mutable struct Zeros{T,N} <: AbstractArray{T,N} dims::NTuple{N,Int} end @@ -30,12 +30,15 @@ Zeros(dims...) = Zeros(Bool, dims...) Base.reshape(x::Zeros{T}, dims::Union{Colon,Int}...) where T = Zeros(T, Base._reshape_uncolon(x, dims)...) Base.getindex(z::Zeros, args...) = error("Calling getindex on Zeros object, materialize to normal array or check for correctness") +# Base.getindex(z::Zeros{T}, args...) where T = zero(T) Base.collect(x::Zeros{T}) where T = zeros(T, x.dims...) Base.size(xs::Zeros) = xs.dims Base.copyto!(a::Zeros, b::Zeros) = b -Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) +# Base.print_array(io::IO, z::Zeros{T}) where T = print(io, "Zeros object with size $(z.dims)") + +Flux.CUDA.Adapt.adapt(to, x::Zeros) = x @adjoint reshape(xs::Zeros{T}, dims...) where T = reshape(xs, dims...), _ -> nothing From a727256808cc87472538e7b9de88fd92f60f4af2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 19 Nov 2020 19:31:47 +0530 Subject: [PATCH 05/13] check Zeros movement tests --- test/cuda/runtests.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl index 8c19caed59..4468070635 100644 --- a/test/cuda/runtests.jl +++ b/test/cuda/runtests.jl @@ -17,6 +17,11 @@ function gpu_gradtest(f, args...) @test l_cpu ≈ l_gpu rtol=1e-4 atol=1e-4 @test g_gpu isa CuArray @test g_cpu ≈ collect(g_gpu) rtol=1e-4 atol=1e-4 + + z = Flux.Zeros() + z2 = Flux.Zeros(3,3) + @test z === gpu(z) + @test z2 === gpu(z2) end @@ -30,4 +35,4 @@ if CUDA.has_cudnn() include("curnn.jl") else @warn "CUDNN unavailable, not testing GPU DNN support" -end \ No newline at end of file +end From d80daa3cf4860e9cac487ca610a9b66e5f5dd8d2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 28 Jun 2021 18:00:14 +0530 Subject: [PATCH 06/13] cleanup some tests --- test/utils.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index a3fe88d5ae..c64767ec0a 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -236,11 +236,11 @@ end end @testset "Zeros" begin - m = Dense(3,2; bias=false) - @test f64(m).b === m.b === Zeros() - @test f32(m).b === m.b === Zeros() + m = Dense(3, 2; bias = false) + @test f64(m).b === m.b + @test f32(m).b === m.b - @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3)) + @testset "Gradients for broadcasted $op with sizes $s" for op in (+, -, *), s in ((1,), (2,3)) o = ones(s) z = zeros(s) Z = Zeros() @@ -255,12 +255,12 @@ end bz = Zeros(3,3) for op in (+, -) - @test op(a,b) == op(a, bz) + @test op(a, b) == op(a, bz) end for op in (+, -) - gs = gradient((a,b) -> sum(op(a, b)), a, b) - gsz = gradient((a,b) -> sum(op(a, b)), a, bz) + gs = gradient((a, b) -> sum(op(a, b)), a, b) + gsz = gradient((a, b) -> sum(op(a, b)), a, bz) @test gs[1] == gsz[1] @test gsz[2] === nothing end @@ -282,11 +282,11 @@ end end @testset "Explicit" begin - gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) + gfun(args...) = gradient((x, y) -> sum(op.(x, y)), args...) g = gfun(o, z) @test gfun(o, Z) == (g[1], nothing) - g = gfun(z, o) + g = gfun(z, o) @test gfun(Z, o) == (nothing, g[2]) end @@ -296,12 +296,12 @@ end gres = gfun(o, Z) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end @@ -322,11 +322,11 @@ end g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end - @testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3)) + @testset "Gradients for $op with sizes $s" for op in (+, -), s in (tuple(), (1,), (2,3)) o = ones(s) z = zeros(s) Z = Zeros() @@ -347,12 +347,12 @@ end g = gfun(o, z) gres = gfun(o, Z) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end end From 0ed9755c0c35b2e5195c0f8baf1aa3adf74b1bbd Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 30 Jun 2021 13:51:07 +0530 Subject: [PATCH 07/13] fuller coverage of the API --- src/zeros.jl | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/zeros.jl b/src/zeros.jl index dfeb83070f..e5d09b5dcc 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -14,9 +14,6 @@ Useful to turn bias off for a forward pass of a layer. julia> bias_less_conv = Conv((2,2), 1=>3; bias = false) Conv((2, 2), 1=>3) -julia> params(bias_less_conv) |> length -1 - julia> bias_less_conv.bias Flux.Zeros() ``` @@ -43,6 +40,8 @@ Flux.CUDA.Adapt.adapt(to, x::Zeros) = x @adjoint reshape(xs::Zeros{T}, dims...) where T = reshape(xs, dims...), _ -> nothing +# @adjoint Zeros(args...) = Zeros(args...), _ -> nothing + # Define basic ops for f in (:+, :-) @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) @@ -65,22 +64,36 @@ for op in (:+, :-) end end +function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where {T} + bs = Broadcast.broadcast_shape(size(a), size(b)) + fill!(similar(a, bs), zero(T)) +end +broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = b .* a + broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) -for op in (:+, :-) +for op in (:+, :-, :*) @eval @adjoint function Base.broadcasted(::typeof($op), a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} Base.broadcasted($op, a, b), Δ -> begin dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) - da = dims == tuple(1:N...) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) + da = dims == Tuple(1:N) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) (nothing, da, nothing) end end + + @eval @adjoint function Base.broadcasted(::typeof($op), a::Zeros{<:Any, N}, b::AbstractArray{<: Number, M}) where {M, N} + a .* b, Δ -> begin + dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) + da = dims == Tuple(1:N) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) + (nothing, nothing, da) + end + end end Base.sum(z::Zeros{T}) where T = zero(T) -for op in (:+, :-) +for op in (:+, :-, :*) @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} $op(a, b), Δ -> begin (Δ, nothing) @@ -96,3 +109,9 @@ broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b +# broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) +# broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(b) + +broadcasted(::typeof(conj), z::Zeros) = z + +@adjoint broadcasted(::typeof(*), a::Zeros{S,0}, b::AbstractArray{T}) where {S, T <: Number} = a .* b, Δ -> (nothing, nothing, Δ) From c6ff709b5c3659360cd82adcb057ab62e2f524f4 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 30 Jun 2021 13:51:26 +0530 Subject: [PATCH 08/13] cleanup some tests --- test/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index c64767ec0a..9f25468429 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -270,7 +270,7 @@ end bz = Zeros(3,3,3) for op in (+, -) - @test broadcast(op, a,b) == broadcast(op, a, bz) + @test broadcast(op, a, b) == broadcast(op, a, bz) end for op in (+, -) @@ -308,7 +308,7 @@ end @testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() # Only defined for 0-dim + Z = Zeros(s...) # Only defined for 0-dim @testset "Explicit" begin gfun(args...) = gradient((x, y) -> sum(x ./ y), args...) @@ -329,7 +329,7 @@ end @testset "Gradients for $op with sizes $s" for op in (+, -), s in (tuple(), (1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() + Z = Zeros(s...) @testset "Explicit" begin @@ -374,7 +374,7 @@ end dl(4, 3, bias) ) - nobias(n) = Zeros() + nobias(n) = Zeros(n) testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) @test l1.W == l2.W @test l1.b == l2.b From 41cee36bcf5b7b3f39e3af2815c5c91b32d41ba1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 30 Jun 2021 16:04:18 +0530 Subject: [PATCH 09/13] add dims test --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 9f25468429..c699039a04 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -243,7 +243,7 @@ end @testset "Gradients for broadcasted $op with sizes $s" for op in (+, -, *), s in ((1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() + Z = Zeros(s) a = ones(3,3) b = zeros(3,3) b′ = Zeros(3,3) From b4f52e82d2fa4432025019180c4b900e00476cbc Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 30 Jun 2021 16:29:23 +0530 Subject: [PATCH 10/13] add adjoint for non-broadcasted ops --- src/zeros.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/zeros.jl b/src/zeros.jl index e5d09b5dcc..9fd05427c6 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -99,6 +99,10 @@ for op in (:+, :-, :*) (Δ, nothing) end end + + @eval @adjoint function $op(a::Zeros, b::AbstractArray) + $op(a, b), Δ -> (nothing, Δ) + end end # Some opportunities to avoid scalar indexing, intermediaries From 7f925c729a5d9ccc20750e3017d07c73af7b112f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 30 Jun 2021 16:45:32 +0530 Subject: [PATCH 11/13] cleanup --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index c699039a04..c26aeb99eb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -243,7 +243,7 @@ end @testset "Gradients for broadcasted $op with sizes $s" for op in (+, -, *), s in ((1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros(s) + Z = Zeros(s...) a = ones(3,3) b = zeros(3,3) b′ = Zeros(3,3) From 06c4505205173a526f7df52decb396088d68426c Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 15 Jul 2021 00:20:04 +0530 Subject: [PATCH 12/13] add fns for / also --- src/zeros.jl | 127 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/src/zeros.jl b/src/zeros.jl index 9fd05427c6..1c4733e9a8 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -26,8 +26,12 @@ Zeros(::Type{T}, dims...) where T = Zeros{T,length(dims)}(dims) Zeros(dims...) = Zeros(Bool, dims...) Base.reshape(x::Zeros{T}, dims::Union{Colon,Int}...) where T = Zeros(T, Base._reshape_uncolon(x, dims)...) -Base.getindex(z::Zeros, args...) = error("Calling getindex on Zeros object, materialize to normal array or check for correctness") -# Base.getindex(z::Zeros{T}, args...) where T = zero(T) + +function Base.getindex(z::Zeros{T}, args...) where T + Base.checkbounds(z, args...) + zero(T) +end + Base.collect(x::Zeros{T}) where T = zeros(T, x.dims...) Base.size(xs::Zeros) = xs.dims @@ -40,8 +44,6 @@ Flux.CUDA.Adapt.adapt(to, x::Zeros) = x @adjoint reshape(xs::Zeros{T}, dims...) where T = reshape(xs, dims...), _ -> nothing -# @adjoint Zeros(args...) = Zeros(args...), _ -> nothing - # Define basic ops for f in (:+, :-) @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) @@ -63,6 +65,17 @@ for op in (:+, :-) sz .= a end end +broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) +broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) + +# a * b +function *(a::Flux.Zeros, b::AbstractMatrix{<: Number}) + sa = size(a) + sb = size(b) + @assert sa[2] == sb[1] throw(DimensionMismatch("dimensions must match")) + zero(b) +end +*(a::AbstractMatrix{<: Number}, b::Zeros) = b * a function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where {T} bs = Broadcast.broadcast_shape(size(a), size(b)) @@ -70,42 +83,108 @@ function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where {T} end broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = b .* a -broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) -broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) +# Adjoints + +# grad( a $op b ) +for op in (:+, :-, :*) + @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} + $op(a, b), Δ -> begin + (Δ, nothing) + end + end + + if op === :- + continue + end + @eval @adjoint function $op(a::Zeros, b::AbstractArray) + $op(a, b), Δ -> begin + @show size(a), size(b) + (nothing, Δ) + end + end +end +@adjoint function -(a::Zeros, b::AbstractArray) + a - b, Δ -> (nothing, -Δ) +end + +# grad( broadcast($op, a, b) ) for op in (:+, :-, :*) @eval @adjoint function Base.broadcasted(::typeof($op), a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} Base.broadcasted($op, a, b), Δ -> begin dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) - da = dims == Tuple(1:N) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) + da = dims == Tuple(1:N) ? Δ : M == N ? broadcast($op, Δ, b) : dropdims(sum(Δ, dims = dims), dims = dims) (nothing, da, nothing) end end - @eval @adjoint function Base.broadcasted(::typeof($op), a::Zeros{<:Any, N}, b::AbstractArray{<: Number, M}) where {M, N} - a .* b, Δ -> begin - dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) - da = dims == Tuple(1:N) ? Δ : dropdims(sum(Δ, dims = dims), dims = dims) - (nothing, nothing, da) + if op === :- + continue + end + @eval @adjoint function Base.broadcasted(::typeof($op), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N} + Base.broadcasted($op, a, b), Δ -> begin + dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...) + db = dims == Tuple(1:M) ? (Δ .* a) : M == N ? broadcast($op, a, Δ) : dropdims(sum(Δ .* a, dims = dims), dims = dims) + (nothing, nothing, db) end end end -Base.sum(z::Zeros{T}) where T = zero(T) -for op in (:+, :-, :*) - @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} - $op(a, b), Δ -> begin - (Δ, nothing) - end +@adjoint function Base.broadcasted(::typeof(-), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N} + a .- b, Δ -> begin + dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...) + da = dims == Tuple(1:M) ? Δ : M == N ? Δ .- a : dropdims(sum(Δ, dims = dims), dims = dims) + (nothing, nothing, -da) end +end - @eval @adjoint function $op(a::Zeros, b::AbstractArray) - $op(a, b), Δ -> (nothing, Δ) - end +# / +_div(a, sa, f) = fill!(similar(a, sa), f) +function /(a::Zeros, b::AbstractMatrix{T}) where T + sa = size(a) + sb = size(b) + @assert sa[end] == sb[end] throw(DimensionMismatch()) + _div(b, size(b), zero(T)) +end +function /(a::AbstractMatrix, b::Zeros) + sa = size(a) + sb = size(b) + @assert sa[end] == sb[end] throw(DimensionMismatch()) + _div(a, size(a), Inf) +end +function broadcasted(::typeof(/), a::Zeros, b::AbstractArray{T}) where T + bs = Broadcast.broadcast_shape(size(a), size(b)) + _div(b, bs, zero(T)) +end +function broadcasted(::typeof(/), a::AbstractArray, b::Zeros) + bs = Broadcast.broadcast_shape(size(a), size(b)) + _div(a, bs, Inf) +end + +# grad( a / b) +@adjoint function /(a::Zeros, b::AbstractArray) + a / b, Δ -> (nothing, zero(Δ)) +end +@adjoint function /(a::AbstractArray, b::Zeros) + a / b, Δ -> (zero(Δ), nothing) end -# Some opportunities to avoid scalar indexing, intermediaries +# grad( broadcast(/, a, b) ) +@adjoint function broadcasted(::typeof(/), a::Zeros{<: Number}, b::AbstractArray{<: Number}) # where T <: Number + sa, sb = size(a), size(b) + T = eltype(b) + a ./ b, Δ -> (nothing, nothing, _div(b, Broadcast.broadcast_shape(sa, sb), zero(T))) +end +@adjoint function broadcasted(::typeof(/), a::AbstractArray{<: Number}, b::Zeros{<: Number}) + sa, sb = size(a), size(b) + T = eltype(b) + a ./ b, Δ -> (nothing, _div(a, sa, Inf), nothing) +end + +Base.sum(z::Zeros{T}) where T = zero(T) + +# Some opportunities to avoid scalar indexing/ intermediaries # Since it replicates a little of what we expect Base to do, # it should be possible to remove in the future, but for now, # these help with performance. @@ -113,9 +192,5 @@ broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b -# broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) -# broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(b) broadcasted(::typeof(conj), z::Zeros) = z - -@adjoint broadcasted(::typeof(*), a::Zeros{S,0}, b::AbstractArray{T}) where {S, T <: Number} = a .* b, Δ -> (nothing, nothing, Δ) From e3323ed514cd961e9889bac519b793f8d8bdc9d1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 15 Jul 2021 17:27:54 +0530 Subject: [PATCH 13/13] use zeros consistently --- src/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 06b2bb01b0..d4a74c5137 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -386,8 +386,9 @@ to the constructor's keyword `bias=bias`. * `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted. """ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) - bias ? fill!(similar(weights, dims...), 0) : Zeros() + bias ? fill!(similar(weights, dims...), 0) : Zeros(dims...) end + function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) bias