Fix convolution & pooling type-stability #370
Conversation
This looks really good! I believe you'll need to use BenchmarkTools for the F and B numbers too. Which Julia version was this run on?
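For reference, a minimal sketch (assuming the F and B numbers refer to the forward and backward passes) of how those timings could be taken with BenchmarkTools so that compilation time is excluded:
using Flux, BenchmarkTools

c = Conv((2, 2), 1 => 3)
x = zeros(Float32, 28, 28, 1, 5)

@btime $c($x)                          # forward pass
@btime gradient(m -> sum(m($x)), $c)   # backward pass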
We probably don't want to have small Unions be part of the dispatch.
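For context, a hedged illustration of what a "small Union in the dispatch" means here (the type and function names below are made up, not NNlib's):
abstract type SomeDims end
struct DenseDims <: SomeDims end
struct DepthwiseDims <: SomeDims end

# Small-Union signature: every new dims subtype has to be added to the Union by hand.
kernel_size(d::Union{DenseDims, DepthwiseDims}) = (3, 3)

# Supertype signature: new subtypes dispatch here automatically.
kernel_size(d::SomeDims) = (3, 3)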
Could we move the syntax to match the current one? It's harder to look for the most important changes when there are unrelated changes in a PR. The more to the point the changes, the better
Re note: I believe in most cases that means that the inference is actually going fine.
I'm using 1.7.1, but the issue is present on previous versions as well. In the PR description I've posted the full output. Here you can see that @code_warntype fails to infer the return type.
Before:
julia> c = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3));
julia> x = zeros(Float32, 28, 28, 1, 5);
julia> @code_warntype c(x)
MethodInstance for (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(::Array{Float32, 4})
from (c::Conv)(x::AbstractArray) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/conv.jl:163
Arguments
c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}
x::Array{Float32, 4}
Locals
#198::Flux.var"#198#199"
cdims::DenseConvDims{2}
b::Array{Float32, 4}
σ::typeof(identity)
Body::Any
1 ─ %1 = Base.getproperty(c, :σ)::Core.Const(identity)
│ %2 = Base.getproperty(c, :bias)::Vector{Float32}
│ %3 = Core.tuple(%2)::Tuple{Vector{Float32}}
│ (#198 = %new(Flux.:(var"#198#199")))
│ %5 = #198::Core.Const(Flux.var"#198#199"())
│ %6 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│ %7 = Flux.length(%6)::Core.Const(2)
│ %8 = Flux.ntuple(%5, %7)::Core.Const((1, 1))
│ %9 = Core.tuple(Flux.:(:), 1)::Core.Const((Colon(), 1))
│ %10 = Core._apply_iterate(Base.iterate, Flux.reshape, %3, %8, %9)::Array{Float32, 4}
│ (σ = %1)
│ (b = %10)
│ %13 = (:stride, :padding, :dilation, :groups)::Core.Const((:stride, :padding, :dilation, :groups))
│ %14 = Core.apply_type(Core.NamedTuple, %13)::Core.Const(NamedTuple{(:stride, :padding, :dilation, :groups)})
│ %15 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│ %16 = Base.getproperty(c, :pad)::NTuple{4, Int64}
│ %17 = Base.getproperty(c, :dilation)::Tuple{Int64, Int64}
│ %18 = Base.getproperty(c, :groups)::Int64
│ %19 = Core.tuple(%15, %16, %17, %18)::Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}
│ %20 = (%14)(%19)::NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}
│ %21 = Core.kwfunc(Flux.DenseConvDims)::Core.Const(Core.var"#Type##kw"())
│ %22 = Base.getproperty(c, :weight)::Array{Float32, 4}
│ (cdims = (%21)(%20, Flux.DenseConvDims, x, %22))
│ %24 = σ::Core.Const(identity)
│ %25 = Base.getproperty(c, :weight)::Array{Float32, 4}
│ %26 = Flux.conv(x, %25, cdims)::AbstractArray{yT, 4} where yT
│ %27 = Base.broadcasted(Flux.:+, %26, b)::Any
│ %28 = Base.broadcasted(%24, %27)::Any
│ %29 = Base.materialize(%28)::Any
└── return %29
After:
julia> c = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3));
julia> x = zeros(Float32, 28, 28, 1, 5);
julia> @code_warntype c(x)
MethodInstance for (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(::Array{Float32, 4})
from (c::Conv)(x::AbstractArray) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/conv.jl:163
Arguments
c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}
x::Array{Float32, 4}
Locals
#198::Flux.var"#198#199"
cdims::DenseConvDims{2, 2, 2, 4, 2}
b::Array{Float32, 4}
σ::typeof(identity)
Body::Array{Float32, 4}
1 ─ %1 = Base.getproperty(c, :σ)::Core.Const(identity)
│ %2 = Base.getproperty(c, :bias)::Vector{Float32}
│ %3 = Core.tuple(%2)::Tuple{Vector{Float32}}
│ (#198 = %new(Flux.:(var"#198#199")))
│ %5 = #198::Core.Const(Flux.var"#198#199"())
│ %6 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│ %7 = Flux.length(%6)::Core.Const(2)
│ %8 = Flux.ntuple(%5, %7)::Core.Const((1, 1))
│ %9 = Core.tuple(Flux.:(:), 1)::Core.Const((Colon(), 1))
│ %10 = Core._apply_iterate(Base.iterate, Flux.reshape, %3, %8, %9)::Array{Float32, 4}
│ (σ = %1)
│ (b = %10)
│ %13 = (:stride, :padding, :dilation, :groups)::Core.Const((:stride, :padding, :dilation, :groups))
│ %14 = Core.apply_type(Core.NamedTuple, %13)::Core.Const(NamedTuple{(:stride, :padding, :dilation, :groups)})
│ %15 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│ %16 = Base.getproperty(c, :pad)::NTuple{4, Int64}
│ %17 = Base.getproperty(c, :dilation)::Tuple{Int64, Int64}
│ %18 = Base.getproperty(c, :groups)::Int64
│ %19 = Core.tuple(%15, %16, %17, %18)::Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}
│ %20 = (%14)(%19)::NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}
│ %21 = Core.kwfunc(Flux.DenseConvDims)::Core.Const(Core.var"#Type##kw"())
│ %22 = Base.getproperty(c, :weight)::Array{Float32, 4}
│ (cdims = (%21)(%20, Flux.DenseConvDims, x, %22))
│ %24 = σ::Core.Const(identity)
│ %25 = Base.getproperty(c, :weight)::Array{Float32, 4}
│ %26 = Flux.conv(x, %25, cdims)::Array{Float32, 4}
│ %27 = Base.broadcasted(Flux.:+, %26, b)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(+), Tuple{Array{Float32, 4}, Array{Float32, 4}}}
│ %28 = Base.broadcasted(%24, %27)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(identity), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(+), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}}
│ %29 = Base.materialize(%28)::Array{Float32, 4}
└── return %29
I've used …
That's right, it would be helpful to see runtime differences.
Runtime differences
Timings above include runtime differences.
Before:
CPU
GPU
After:
CPU
GPU
Maybe for now it is ok, since the runtime is not affected. But in future PRs this can be removed, once the other ConvDims subtypes are reworked.
Since this is a breaking change, we can also tolerate the union, as there will be follow-up PRs for other subtypes of ConvDims as well. Just need to make sure the union is gone before cutting a release :)
Fixed type-inference for pooling and depthwise convolution as well. Side note: I wanted to make the ResNet model fully inferrable, but it turns out BatchNorm is not type-stable:
julia> b = BatchNorm(1);
julia> x = zeros(Float32, 28, 28, 1, 1);
julia> @code_warntype b(x)
MethodInstance for (::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}})(::Array{Float32, 4})
from (BN::BatchNorm)(x) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/normalise.jl:269
Arguments
BN::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}
x::Array{Float32, 4}
Locals
#284::Flux.var"#284#285"{Array{Float32, 4}, Int64}
affine_shape::NTuple{4, Int64}
reduce_dims::Vector{Int64}
N::Int64
Body::Any
1 ─ Core.NewvarNode(:(#284))
│ Core.NewvarNode(:(affine_shape))
│ Core.NewvarNode(:(reduce_dims))
│ Core.NewvarNode(:(N))
│ %5 = Flux.ndims(x)::Core.Const(4)
│ %6 = (%5 - 1)::Core.Const(3)
│ %7 = Flux.size(x, %6)::Int64
│ %8 = Base.getproperty(BN, :chs)::Int64
│ %9 = (%7 == %8)::Bool
└── goto #3 if not %9
2 ─ goto #4
3 ─ %12 = Base.AssertionError("size(x, ndims(x) - 1) == BN.chs")::Any
└── Base.throw(%12)
4 ┄ (N = Flux.ndims(x))
│ %15 = (N::Core.Const(4) - 2)::Core.Const(2)
│ %16 = (1:%15)::Core.Const(1:2)
│ (reduce_dims = Base.vcat(%16, N::Core.Const(4)))
│ %18 = Flux.:(var"#284#285")::Core.Const(Flux.var"#284#285")
│ %19 = Core.typeof(x)::Core.Const(Array{Float32, 4})
│ %20 = Core.typeof(N::Core.Const(4))::Core.Const(Int64)
│ %21 = Core.apply_type(%18, %19, %20)::Core.Const(Flux.var"#284#285"{Array{Float32, 4}, Int64})
│ (#284 = %new(%21, x, N::Core.Const(4)))
│ %23 = #284::Core.PartialStruct(Flux.var"#284#285"{Array{Float32, 4}, Int64}, Any[Array{Float32, 4}, Core.Const(4)])
│ (affine_shape = Flux.ntuple(%23, N::Core.Const(4)))
│ %25 = (:reduce_dims, :affine_shape)::Core.Const((:reduce_dims, :affine_shape))
│ %26 = Core.apply_type(Core.NamedTuple, %25)::Core.Const(NamedTuple{(:reduce_dims, :affine_shape)})
│ %27 = Core.tuple(reduce_dims, affine_shape::Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)]))::Core.PartialStruct(Tuple{Vector{Int64}, NTuple{4, Int64}}, Any[Vector{Int64}, Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)])])
│ %28 = (%26)(%27)::Core.PartialStruct(NamedTuple{(:reduce_dims, :affine_shape), Tuple{Vector{Int64}, NTuple{4, Int64}}}, Any[Vector{Int64}, Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)])])
│ %29 = Core.kwfunc(Flux._norm_layer_forward)::Core.Const(Flux.var"#_norm_layer_forward##kw"())
│ %30 = (%29)(%28, Flux._norm_layer_forward, BN, x)::Any
└── return %30
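Reading the output above, the Any comes from reduce_dims being a Vector{Int} (built with vcat) and from ntuple being called with a runtime N. A hedged sketch (illustrative helpers, not the Flux API) of how both could be computed from the array's static rank instead:
reduce_dims_static(x::AbstractArray{T, N}) where {T, N} =
    ntuple(i -> i == N - 1 ? N : i, N - 1)            # (1, 2, 4) for a 4-d input

affine_shape_static(x::AbstractArray{T, N}, chs::Int) where {T, N} =
    ntuple(i -> i == N - 1 ? chs : 1, N)              # (1, 1, chs, 1) for a 4-d input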
Here's another example with more layers, including pooling.
Code:
using Flux, BenchmarkTools  # imports added for completeness
device = cpu                # assumption: the original snippet leaves device undefined; use gpu for the GPU runs
x = device(zeros(Float32, 224, 224, 1, 5))
y = device(zeros(Float32, 10, 5))
m = device(Chain(
Conv((3, 3), 1 => 2, relu),
MaxPool((2, 2)),
Conv((3, 3), 2 => 2, relu),
MaxPool((2, 2)),
Conv((3, 3), 2 => 3, relu),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
MaxPool((2, 2)),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
Conv((3, 3), 3 => 3, relu),
MaxPool((2, 2)),
x -> reshape(x, :, size(x, 4)),
Dense(147, 10), softmax))
θ = params(m)
@time m(x)
BenchmarkTools.@btime $m($x)
@time gradient(θ) do
Flux.crossentropy(m(x), y)
end
BenchmarkTools.@btime gradient($θ) do
Flux.crossentropy($m($x), $y)
end
Timings in the order they appear in the code above.
Before:
8.336650 seconds (33.78 M allocations: 1.800 GiB, 9.91% gc time, 99.63% compilation time)
4.162 ms (803 allocations: 36.60 MiB)
59.726230 seconds (71.67 M allocations: 3.815 GiB, 2.51% gc time, 99.84% compilation time)
33.088 ms (4586 allocations: 107.94 MiB)
26.857145 seconds (68.45 M allocations: 3.557 GiB, 5.80% gc time, 56.44% compilation time)
309.557 μs (1225 allocations: 70.75 KiB)
67.174837 seconds (104.17 M allocations: 5.416 GiB, 3.66% gc time, 93.56% compilation time)
2.277 ms (6496 allocations: 1.11 MiB)
After:
2.510704 seconds (9.98 M allocations: 550.919 MiB, 7.70% gc time, 99.62% compilation time)
5.438 ms (446 allocations: 36.58 MiB)
42.624773 seconds (65.89 M allocations: 3.536 GiB, 3.14% gc time, 99.80% compilation time)
36.277 ms (3650 allocations: 107.93 MiB)
22.459811 seconds (56.46 M allocations: 2.928 GiB, 6.11% gc time, 51.20% compilation time)
263.607 μs (1066 allocations: 56.62 KiB)
59.634783 seconds (99.78 M allocations: 5.186 GiB, 4.05% gc time, 92.60% compilation time)
2.193 ms (6125 allocations: 1.11 MiB)
RE @inferred, I noticed the tests pass now, so is that no longer an issue? It's likely you were running into JuliaLang/julia#23749, so I'd recommend Cthulhu.jl as the most reliable way to debug inference.
RE BatchNorm, the revised implementation in FluxML/Flux.jl#1509 may be a starting point (if not a complete solution) for fixing inference. Another approach could be to first create a type-stable functional version in NNlib (simultaneously addressing #19) and then retrofit the Flux layer on top of that.
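For reference, a minimal sketch of the Cthulhu workflow (interactive, so it is meant for the REPL rather than a test suite):
using Flux, Cthulhu

c = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
x = zeros(Float32, 28, 28, 1, 5)

# Descend through the call tree; poorly inferred calls are highlighted, which sidesteps
# the @inferred pitfalls from JuliaLang/julia#23749.
@descend c(x)
# Or mirror @code_warntype at every level of the descent:
@descend_code_warntype c(x)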
Strangely, it is working now... Added more inference tests.
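An inference test along these lines could look roughly like the following (a sketch; not necessarily the exact tests added in the PR):
using Test, NNlib

x = rand(Float32, 16, 16, 3, 2)
w = rand(Float32, 3, 3, 3, 4)

@testset "inference" begin
    @inferred NNlib.conv(x, w)
    @inferred NNlib.maxpool(x, PoolDims(x, 2))
    @inferred NNlib.meanpool(x, PoolDims(x, 2))
end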
Ok, I think it's actually pooling, not convolution.
A comparison of profiles between both showed https://github.com/FluxML/NNlib.jl/blob/v0.7.31/src/impl/pooling_direct.jl#L53 appearing post-PR but not on master. My suspicion is that the loop is being unrolled, as demonstrated by this simple example:
function naive_pool(x::Vector{Float32}, ks::Int)
@inbounds for x_i in 1:length(x) - ks
m = x[x_i]
for k_i in 1:ks
m = max(m, x[x_i + k_i])
end
x[x_i] = m
end
end
function naive_pool_val(x::Vector{Float32}, ks::Val{K}) where K
ks = K
for x_i in 1:length(x) - ks
m = x[x_i]
for k_i in 1:ks
m = max(m, x[x_i + k_i])
end
x[x_i] = m
end
end
x = rand(Float32, 1024)
@code_llvm naive_pool(x, 3) # not unrolled
@code_llvm naive_pool_val(x, Val(3)) # unrolled, multiple references to max
Since the kernel size is no longer encoded in the dims type, the inner loop can't be unrolled post-PR.
I managed to shave off ~150μs by using …
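The comment above is cut off, but one common way to recover the unrolling without moving the kernel size back into the dims type is a Val function barrier; a sketch with hypothetical helper names, not necessarily what the PR does:
# Entry point: the kernel size stays a plain runtime value / struct field here.
function pool_entry!(x::Vector{Float32}, ks::Int)
    # Function barrier: Val(ks) turns ks into a compile-time constant for the kernel,
    # at the cost of one dynamic dispatch here and one specialization per kernel size.
    _pool_kernel!(x, Val(ks))
end

function _pool_kernel!(x::Vector{Float32}, ::Val{K}) where {K}
    @inbounds for x_i in 1:length(x) - K
        m = x[x_i]
        for k_i in 1:K
            m = max(m, x[x_i + k_i])   # with K known statically this loop can unroll
        end
        x[x_i] = m
    end
    return x
end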
Just to confirm that I understand correctly: the parameters would be in the dims type as they were before, but they will be passed onto them from, for example, …
In this PR I did fix type-instability for …
Pinging @staticfloat for his eyes on the type stability as well.
With the latest changes I've managed to nearly match the performance.
62.949 μs (13 allocations: 128.97 KiB)
39.504 μs (13 allocations: 128.97 KiB)
1.060 ms (17 allocations: 1.15 MiB)
133.712 μs (17 allocations: 1.15 MiB)
77.597 μs (9 allocations: 128.92 KiB)
74.611 μs (9 allocations: 128.92 KiB)
796.221 μs (13 allocations: 1.15 MiB)
124.385 μs (13 allocations: 1.15 MiB)
First evaluation is slightly slower than it was, but still much faster than on master.
Master:
8.336650 seconds (33.78 M allocations: 1.800 GiB, 9.91% gc time, 99.63% compilation time)
4.162 ms (803 allocations: 36.60 MiB)
59.726230 seconds (71.67 M allocations: 3.815 GiB, 2.51% gc time, 99.84% compilation time)
33.088 ms (4586 allocations: 107.94 MiB)
This PR:
3.406770 seconds (10.15 M allocations: 560.621 MiB, 7.04% gc time, 99.77% compilation time)
3.839 ms (453 allocations: 36.58 MiB)
58.345734 seconds (63.01 M allocations: 3.400 GiB, 2.49% gc time, 99.91% compilation time)
32.999 ms (3665 allocations: 107.93 MiB)
Stupendous, thank you @pxl-th for taking this on! I have a couple of tiny non-behavioural suggestions, but other than that I think this is ready to go in :)
elseif $(name == :mean)
    m += x[input_kw, input_kh, input_kd, c, batch_idx]
else
    error("Unimplemented codegen path")
end
end
- y[w, h, d, c, batch_idx] = alpha * m + beta * y[w, h, d, c, batch_idx]
+ y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
Why can you eliminate the + beta * y term here?
IIRC it was a reaction to this line: https://github.com/FluxML/NNlib.jl/blob/master/src/impl/pooling_direct.jl#L9.
I just went on and added back the other parameters as well. New timings:
62.577 μs (13 allocations: 128.97 KiB)
39.254 μs (13 allocations: 128.97 KiB)
1.051 ms (17 allocations: 1.15 MiB)
123.622 μs (17 allocations: 1.15 MiB)
40.165 μs (13 allocations: 129.06 KiB)
40.146 μs (13 allocations: 129.06 KiB)
825.402 μs (13 allocations: 1.15 MiB)
134.233 μs (13 allocations: 1.15 MiB)
3.014 ms (15 allocations: 1.13 MiB)
10.432 ms (29 allocations: 3.43 MiB)
2.993 ms (26 allocations: 1.15 MiB)
2.760 ms (15 allocations: 1.13 MiB)
10.396 ms (23 allocations: 3.43 MiB)
2.881 ms (19 allocations: 1.15 MiB)
Code:
using BenchmarkTools
using NNlib
using NNlib: ∇maxpool, ∇meanpool, conv_direct, ∇conv_filter_direct, ∇conv_data_direct  # explicit imports in case the direct/gradient helpers are not exported

function main()
x = rand(Float32, 224, 224, 3, 2)
w = rand(Float32, 3, 3, 3, 3)
pdims = PoolDims(x, 3)
cdims = DenseConvDims(x, w)
y_max = maxpool(x, pdims)
y_mean = meanpool(x, pdims)
dy = ones(Float32, size(y_max)...)
@btime maxpool($x, $pdims)
@btime meanpool($x, $pdims)
@btime ∇maxpool($dy, $y_max, $x, $pdims)
@btime ∇meanpool($dy, $y_mean, $x, $pdims)
dy_conv = conv_direct(x, w, cdims)
@btime conv_direct($x, $w, $cdims)
@btime ∇conv_filter_direct($x, $dy_conv, $cdims)
@btime ∇conv_data_direct($dy_conv, $w, $cdims)
nothing
end
Could've been just noise, because right now I'm seeing timings around …
@ToucheSir do you have any further comments? Otherwise this looks good to go.
I haven't read all of this closely, but the numbers look good. Two questions:
Pinging to see if there are still things that need to be addressed.
Haven't followed this thread closely. Are we waiting to merge because it is breaking?
Are there any PRs we want to flush out the queue before a breaking release? If not, I vote to merge.
It is breaking in the sense that …
Let's merge this, and release after updates to the various parts of the ecosystem that specialise on the type params to avoid breakages.
Seems like there are unintentional breakages: https://buildkite.com/julialang/nnlibcuda-dot-jl/builds/149#b44502db-331d-457c-898e-db25d0405df4
I see it is now fixed, thanks @DhairyaLGandhi!
This PR fixes type-stability for convolutions and pooling (and closes FluxML/Flux.jl#1178). It replaces the current ConvDims structure and its descendants with one that does not encode parameters in its type, but instead stores them as struct fields. This allows type inference to succeed and significantly improves compile times.
Code used to benchmark:
Benchmarks:
Before:
After:
Results from @code_warntype:
Before
After
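For illustration, a hedged before/after sketch of the design change (field and parameter names abridged; not the exact NNlib definitions):
# Before: sizes, strides, padding, etc. are all type parameters, so constructing the
# dims from runtime values produces a type that inference cannot pin down.
struct OldDenseConvDims{N, K, C_in, C_out, S, P, D, F} end

# After: only tuple lengths remain in the type; the concrete values are ordinary fields,
# so construction from runtime values is type-stable.
struct NewDenseConvDims{N, S, P, D}
    input_size::NTuple{N, Int}
    kernel_size::NTuple{N, Int}
    channels_in::Int
    channels_out::Int
    stride::NTuple{S, Int}
    padding::NTuple{P, Int}
    dilation::NTuple{D, Int}
    flipkernel::Bool
end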