Commit 2d7ca99: Formatting

theabhirath committed May 28, 2022
1 parent: 681a27d

Showing 2 changed files with 8 additions and 4 deletions.
src/layers/basic.jl (9 changes: 6 additions & 3 deletions)
@@ -727,15 +727,18 @@ function Base.show(io::IO, d::Scale)
     return print(io, ")")
 end
 
-function Scale(dims::Tuple{Vararg{Integer}}, activation=identity; init_weight=glorot_uniform,
+function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
+               init_weight=glorot_uniform,
                init_bias=zeros32, bias::Bool=true)
     activation = NNlib.fast_act(activation)
     return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight),
                  typeof(init_bias)}(activation, dims, init_weight, init_bias)
 end
 
-Scale(s1::Integer, s23::Integer...; _act = identity, kw...) = Scale(tuple(s1, s23...), _act; kw...)
-Scale(size_act...; kw...) = Scale(size_act[1:end-1]...; _act = size_act[end], kw...)
+function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
+    Scale(tuple(s1, s23...), _act; kw...)
+end
+Scale(size_act...; kw...) = Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
 
 function initialparameters(rng::AbstractRNG, d::Scale{true})
     return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
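
For reference, a minimal sketch of how the convenience constructors formatted above compose; this snippet is illustrative (it assumes `Scale` is brought in with `using Lux`, as in the tests below) rather than taken from the repository:

using Lux

# Varargs form: the trailing argument is treated as the activation, so this
# call routes through `Scale(2, 3; _act=tanh)` and ultimately builds the
# canonical tuple form.
layer_a = Scale(2, 3, tanh)

# Equivalent direct construction via the tuple method shown in the hunk above.
layer_b = Scale((2, 3), tanh)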
test/layers/basic.jl (3 changes: 2 additions & 1 deletion)
@@ -310,7 +310,8 @@ end
         end == [2.0 3.0; 4.0 5.0]
 
         @test begin
-            layer = Scale(2, tanh; bias = false, init_weight=zeros)
+            layer = Scale(2, tanh; bias=false, init_weight=zeros)
             first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...))
         end == zeros(2, 2)
     end
 end
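
Written out as a standalone snippet, the reformatted test corresponds to roughly the following; this is a sketch that reuses the calls shown in the hunk above, with the `rng` construction added here for illustration:

using Lux, Random

rng = Random.default_rng()

# Zero weights and no bias: the elementwise scale produces an all-zero output
# for any input, which is exactly what the test asserts.
layer = Scale(2, tanh; bias=false, init_weight=zeros)
ps, st = Lux.setup(rng, layer)
y = first(Lux.apply(layer, [1 2; 3 4], ps, st))
y == zeros(2, 2)  # true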
