Allow Zeros with dimensions #1402

Closed
wants to merge 17 commits into from
3 changes: 2 additions & 1 deletion src/utils.jl
@@ -382,8 +382,9 @@ to the constructor's keyword `bias=bias`.
* `bias::AbstractArray` uses the provided array, so long as it has the correct size. If the eltype differs, it will be converted.
"""
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
-  bias ? fill!(similar(weights, dims...), 0) : Zeros()
bias ? fill!(similar(weights, dims...), 0) : Zeros(dims...)
end

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
bias
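For reference, a minimal sketch of how the three `bias` arguments behave after this change (illustrative only; `w` stands in for a layer's weight matrix):

    w = randn(Float32, 2, 3)                    # e.g. the weight of Dense(3, 2)
    Flux.create_bias(w, true, 2)                # 2-element zero vector, same storage as w
    Flux.create_bias(w, false, 2)               # Flux.Zeros(2): sized, allocation-free placeholder
    Flux.create_bias(w, zeros(Float32, 2), 2)   # uses the given array after a size check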
194 changes: 169 additions & 25 deletions src/zeros.jl
@@ -14,39 +14,183 @@ Useful to turn bias off for a forward pass of a layer.
julia> bias_less_conv = Conv((2,2), 1=>3; bias = false)
Conv((2, 2), 1=>3)

julia> params(bias_less_conv) |> length
1

julia> bias_less_conv.bias
Flux.Zeros()
```
"""
-struct Zeros end
-# To allow for things like Dense(10, 2, initb = Zeros)
-Zeros(args...) = Zeros()
mutable struct Zeros{T,N} <: AbstractArray{T,N}
dims::NTuple{N,Int}
end

Zeros(::Type{T}, dims...) where T = Zeros{T,length(dims)}(dims)
Zeros(dims...) = Zeros(Bool, dims...)

Base.reshape(x::Zeros{T}, dims::Union{Colon,Int}...) where T = Zeros(T, Base._reshape_uncolon(x, dims)...)

function Base.getindex(z::Zeros{T}, args...) where T
Base.checkbounds(z, args...)
zero(T)
end

Base.collect(x::Zeros{T}) where T = zeros(T, x.dims...)

Base.size(xs::Zeros) = xs.dims
Base.copyto!(a::Zeros, b::Zeros) = a  # copyto! returns its destination

# Base.print_array(io::IO, z::Zeros{T}) where T = print(io, "Zeros object with size $(z.dims)")

Flux.CUDA.Adapt.adapt(to, x::Zeros) = x

@adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing
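A quick sketch of the array-like behaviour these definitions give `Zeros` (assuming the methods above):

    z = Flux.Zeros(2, 3)   # Zeros{Bool,2}; only the dims are stored
    size(z)                # (2, 3)
    z[1, 2]                # bounds-checked, always zero(T) == false
    collect(z)             # materialises a 2×3 Matrix{Bool} of zeros
    reshape(z, 3, :)       # still a Zeros, now with dims (3, 2)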

# Define basic ops
for f in (:+, :-)
@eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
    size(a) == size(b) || throw(DimensionMismatch("dimensions must match"))
a
end
end

+(a::Zeros, b::AbstractArray) = b + a
-(a::Zeros, b::AbstractArray) = -b + a

Base.copy(xs::Zeros{T,N}) where {T,N} = xs

for op in (:+, :-)
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
bs = Broadcast.broadcast_shape(size(a), size(b))
size(a) == bs && return a
sz = similar(a, bs)
sz .= a
end
end
broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)

-Base.reshape(x::Zeros, dims...) = x
# a * b
# The product with an all-zero factor is a zero matrix of the product's shape.
function *(a::Zeros, b::AbstractMatrix{<: Number})
  size(a, 2) == size(b, 1) || throw(DimensionMismatch("dimensions must match"))
  fill!(similar(b, size(a, 1), size(b, 2)), zero(eltype(b)))
end
function *(a::AbstractMatrix{<: Number}, b::Zeros)
  size(a, 2) == size(b, 1) || throw(DimensionMismatch("dimensions must match"))
  fill!(similar(a, size(a, 1), size(b, 2)), zero(eltype(a)))
end

-+(::Zeros, b::AbstractArray) = b
-+(a::AbstractArray, ::Zeros) = a
-+(a::Zeros, ::Zeros) = a
function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where {T}
bs = Broadcast.broadcast_shape(size(a), size(b))
fill!(similar(a, bs), zero(T))
end
broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = b .* a

--(::Zeros, b::AbstractArray) = -b
--(a::AbstractArray, ::Zeros) = a
--(a::Zeros, ::Zeros) = a
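Taken together, `+` and `-` with a same-sized `Zeros` are no-ops that return the other argument (or its negation) without copying, while `*` materialises an actual zero array. A sketch of the intended behaviour:

    a = rand(3, 3)
    z = Flux.Zeros(3, 3)
    (a + z) === a        # returns `a` itself after the size check
    (a .- z) === a       # broadcast fast-path: no copy when shapes already match
    a .* z               # a materialised 3×3 zero Matrix{Float64}
    z * rand(3, 2)       # 3×2 zero matrix, after the usual inner-dimension check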
# Adjoints

# Some opportunities to avoid scalar indexing, intermediaries
# grad( a $op b )
for op in (:+, :-)
  @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N, M}
    $op(a, b), Δ -> (Δ, nothing)
  end
end
@adjoint function +(a::Zeros, b::AbstractArray)
  a + b, Δ -> (nothing, Δ)
end
@adjoint function -(a::Zeros, b::AbstractArray)
  a - b, Δ -> (nothing, -Δ)
end
# For *, the non-Zeros argument also gets a zero gradient, since the other
# factor is identically zero; passing Δ through would have the wrong shape.
@adjoint function *(a::AbstractArray{<: Number}, b::Zeros)
  a * b, Δ -> (zero(a), nothing)
end
@adjoint function *(a::Zeros, b::AbstractArray{<: Number})
  a * b, Δ -> (nothing, zero(b))
end
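The net effect for training is that the non-`Zeros` argument receives the usual gradient while the `Zeros` argument receives `nothing`, so a disabled bias can never accumulate updates. A minimal check (sketch, assuming Zygote):

    using Flux, Zygote

    a = rand(3, 3)
    z = Flux.Zeros(3, 3)
    g = Zygote.gradient((x, y) -> sum(x + y), a, z)
    g[1]               # ones(3, 3), exactly as with a real zero matrix
    g[2] === nothing   # true: no gradient is ever produced for the Zeros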


# grad( broadcast($op, a, b) )
for op in (:+, :-, :*)
@eval @adjoint function Base.broadcasted(::typeof($op), a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M}
Base.broadcasted($op, a, b), Δ -> begin
dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...)
da = dims == Tuple(1:N) ? Δ : M == N ? broadcast($op, Δ, b) : dropdims(sum(Δ, dims = dims), dims = dims)
(nothing, da, nothing)
end
end

if op === :-
continue
end
@eval @adjoint function Base.broadcasted(::typeof($op), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N}
Base.broadcasted($op, a, b), Δ -> begin
dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...)
db = dims == Tuple(1:M) ? (Δ .* a) : M == N ? broadcast($op, a, Δ) : dropdims(sum(Δ .* a, dims = dims), dims = dims)
(nothing, nothing, db)
end
end
end


@adjoint function Base.broadcasted(::typeof(-), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N}
a .- b, Δ -> begin
dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...)
da = dims == Tuple(1:M) ? Δ : M == N ? Δ .- a : dropdims(sum(Δ, dims = dims), dims = dims)
(nothing, nothing, -da)
end
end
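When the `Zeros` carries extra dimensions, the pullback reduces `Δ` back to the other argument's shape, mirroring what Zygote does for ordinary broadcasts. A sketch of the expected behaviour:

    a = rand(3, 3)
    z = Flux.Zeros(3, 3, 3)    # broadcasts against `a` along the trailing dim
    g = Zygote.gradient((x, y) -> sum(x .+ y), a, z)
    size(g[1]) == (3, 3)       # Δ summed over the broadcast dimension
    g[2] === nothing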

# /
_div(a, sa, f) = fill!(similar(a, sa), f)
function /(a::Zeros, b::AbstractMatrix{T}) where T
  size(a)[end] == size(b)[end] || throw(DimensionMismatch("dimensions must match"))
  _div(b, size(b), zero(T))
end
function /(a::AbstractMatrix, b::Zeros)
  size(a)[end] == size(b)[end] || throw(DimensionMismatch("dimensions must match"))
  _div(a, size(a), Inf)
end
function broadcasted(::typeof(/), a::Zeros, b::AbstractArray{T}) where T
bs = Broadcast.broadcast_shape(size(a), size(b))
_div(b, bs, zero(T))
end
function broadcasted(::typeof(/), a::AbstractArray, b::Zeros)
bs = Broadcast.broadcast_shape(size(a), size(b))
_div(a, bs, Inf)
end
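Division follows elementwise `0/x` and `x/0` semantics; a sketch:

    a = rand(2, 2)
    z = Flux.Zeros(2, 2)
    z ./ a    # all zeros, with eltype of `a`
    a ./ z    # filled with Inf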

# grad( a / b)
@adjoint function /(a::Zeros, b::AbstractArray)
a / b, Δ -> (nothing, zero(Δ))
end
@adjoint function /(a::AbstractArray, b::Zeros)
a / b, Δ -> (zero(Δ), nothing)
end

# grad( broadcast(/, a, b) )
@adjoint function broadcasted(::typeof(/), a::Zeros{<: Number}, b::AbstractArray{<: Number})
  sa, sb = size(a), size(b)
  a ./ b, Δ -> (nothing, nothing, _div(b, Broadcast.broadcast_shape(sa, sb), zero(eltype(b))))
end
@adjoint function broadcasted(::typeof(/), a::AbstractArray{<: Number}, b::Zeros{<: Number})
  a ./ b, Δ -> (nothing, _div(a, size(a), Inf), nothing)
end

Base.sum(z::Zeros{T}) where T = zero(T)

# Some opportunities to avoid scalar indexing/ intermediaries
# Since it replicates a little of what we expect Base to do,
# it should be possible to remove in the future, but for now,
# these help with performance.
-broadcasted(::typeof(+), a::AbstractArray, b::Zeros) = a
-broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b
-broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a
-broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b
-# Need adjoints for these or else the gradient w.r.t. the non-Zeros arg will be nothing as well
-@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing)
-@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))
-@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))

# Pass-through for layer constructors
create_bias(weights::AbstractArray, bias::Flux.Zeros, dims::Integer...) = bias
broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b

broadcasted(::typeof(conj), z::Zeros) = z
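Combined with the `create_bias` change in src/utils.jl, a layer built with `bias = false` now carries a sized placeholder instead of a 0-dim singleton; a sketch:

    using Flux

    m = Conv((2, 2), 1 => 3; bias = false)
    m.bias                              # Flux.Zeros(3): one zero per output channel
    y = m(rand(Float32, 4, 4, 1, 1))    # forward pass works; the bias add is a no-op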
21 changes: 21 additions & 0 deletions test/cuda/runtests.jl
@@ -5,6 +5,27 @@ using Zygote: pullback
@info "Testing GPU Support"
CUDA.allowscalar(false)

function gpu_gradtest(f, args...)
args_gpu = gpu.(args)

l_cpu, back_cpu = pullback((x...) -> f(x...), args...)
g_cpu = back_cpu(1f0)[1]

l_gpu, back_gpu = pullback((x...) -> f(x...), args_gpu...)
g_gpu = back_gpu(1f0)[1]

@test l_cpu ≈ l_gpu rtol=1e-4 atol=1e-4
@test g_gpu isa CuArray
@test g_cpu ≈ collect(g_gpu) rtol=1e-4 atol=1e-4
end
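For example, a minimal call (hypothetical loss, just to show the calling convention — the helper compares CPU and GPU losses and gradients):

    gpu_gradtest(x -> sum(abs2, x), rand(Float32, 3, 3))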

@testset "Moving Zeros to GPU" begin
z = Flux.Zeros()
z2 = Flux.Zeros(3,3)
@test z === gpu(z)
@test z2 === gpu(z2)
end

include("test_utils.jl")
include("cuda.jl")
include("losses.jl")
65 changes: 51 additions & 14 deletions test/utils.jl
@@ -247,21 +247,58 @@ end
end

@testset "Zeros" begin
-  m = Dense(3,2; bias=false)

m = Dense(3, 2; bias = false)
  @test f64(m).bias === m.bias
  @test f32(m).bias === m.bias
  @test m.bias == Zeros(2)

@testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
@testset "Gradients for broadcasted $op with sizes $s" for op in (+, -, *), s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
-    Z = Zeros()
Z = Zeros(s...)
a = ones(3,3)
b = zeros(3,3)
b′ = Zeros(3,3)


@testset "Basic operations" begin
a = rand(3,3)
b = zeros(3,3)
bz = Zeros(3,3)

for op in (+, -)
@test op(a, b) == op(a, bz)
end

for op in (+, -)
gs = gradient((a, b) -> sum(op(a, b)), a, b)
gsz = gradient((a, b) -> sum(op(a, b)), a, bz)
@test gs[1] == gsz[1]
@test gsz[2] === nothing
end

# Check with broadcasting
b = zeros(3,3,3)
bz = Zeros(3,3,3)

for op in (+, -)
@test broadcast(op, a, b) == broadcast(op, a, bz)
end

for op in (+, -)
gs = gradient((a,b) -> sum(broadcast(op, a, b)), a, b)
gsz = gradient((a,b) -> sum(broadcast(op, a, b)), a, bz)
@test gs[1] == gsz[1]
@test gsz[2] === nothing
end
end

@testset "Explicit" begin
-      gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
gfun(args...) = gradient((x, y) -> sum(op.(x, y)), args...)
g = gfun(o, z)
@test gfun(o, Z) == (g[1], nothing)

      g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])
end

@@ -271,19 +308,19 @@

gres = gfun(o, Z)
@test gres[o] == g[o]
-      @test Z ∉ gres.params
# @test Z ∉ gres.params

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
-      @test Z ∉ gres.params
# @test Z ∉ gres.params
end
end

@testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
-    Z = Zeros() # Only defined for 0-dim
Z = Zeros(s...) # Only defined for 0-dim

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
@@ -297,14 +334,14 @@ end
g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
-      @test Z ∉ gres.params
# @test Z ∉ gres.params
end
end

@testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3))
@testset "Gradients for $op with sizes $s" for op in (+, -), s in (tuple(), (1,), (2,3))
o = ones(s)
z = zeros(s)
-    Z = Zeros()
Z = Zeros(s...)


@testset "Explicit" begin
Expand All @@ -322,12 +359,12 @@ end
g = gfun(o, z)
gres = gfun(o, Z)
@test gres[o] == g[o]
-      @test Z ∉ gres.params
# @test Z ∉ gres.params

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
-      @test Z ∉ gres.params
# @test Z ∉ gres.params
end
end
end
@@ -368,7 +405,7 @@ end
dl(4, 3, bias)
)

-  nobias(n) = Zeros()
nobias(n) = Zeros(n)
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.weight == l2.weight
@test l1.bias == l2.bias