Fix convolution & pooling type-stability #370

Merged (23 commits into FluxML:master) on Jan 23, 2022

Conversation

@pxl-th (Member) commented Jan 8, 2022

This PR fixes type-stability for convolutions and pooling (and closes FluxML/Flux.jl#1178). It replaces the current ConvDims structure and its descendants with one that does not encode parameters in its type, but instead stores them as struct fields. This allows type inference to succeed and significantly improves compile times.
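
Roughly, the change amounts to the following (a simplified sketch for illustration; the struct names and fields here are illustrative, not the exact NNlib definitions):

# Old style: hyper-parameters are encoded as type parameters, so constructing the dims
# object from runtime values (strides, padding, ...) cannot be inferred to one concrete type.
struct TypeLevelPoolDims{N, K, S, P, D} end

# New style: the same information lives in ordinary fields, so the constructor always
# returns the same concrete type and inference succeeds.
struct FieldLevelPoolDims{N, M}
    input_size::NTuple{N, Int}
    kernel_size::NTuple{N, Int}
    stride::NTuple{N, Int}
    padding::NTuple{M, Int}     # lo/hi padding per spatial dimension, so M == 2N
    dilation::NTuple{N, Int}
end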


Code used to benchmark (assumes using Flux and BenchmarkTools; for the CPU runs, device can be switched to cpu):
function main()
    device = gpu

    x = device(zeros(Float32, 28, 28, 1, 5))
    y = device(zeros(Float32, 10, 5))
    m = device(Chain(
        Conv((3, 3), 1 => 2, relu),
        Conv((3, 3), 2 => 2, relu),
        Conv((3, 3), 2 => 3, relu),
        Conv((3, 3), 3 => 3, relu),
        Conv((3, 3), 3 => 3, relu),
        Conv((3, 3), 3 => 3, relu),
        x -> reshape(x, :, size(x, 4)),
        Dense(768, 10), softmax))
    θ = params(m)

    # F: first forward call (includes compilation)
    @time m(x)
    # FB: steady-state forward time (BenchmarkTools, excludes compilation)
    BenchmarkTools.@btime $m($x)
    # B: first backward call (includes compilation)
    @time gradient(θ) do
        Flux.crossentropy(m(x), y)
    end
end

Benchmarks:

Before:

  • CPU:
(F)  7.112869 seconds (31.64 M allocations: 1.649 GiB, 11.51% gc time, 99.94% compilation time)
(FB) 315.728 μs (387 allocations: 1.86 MiB)
(B)  33.839269 seconds (70.09 M allocations: 3.635 GiB, 3.73% gc time, 99.95% compilation time)
  • GPU:
(F)  32.785819 seconds (63.27 M allocations: 3.296 GiB, 5.00% gc time, 51.70% compilation time)
(FB) 145.766 μs (617 allocations: 32.97 KiB)
(B)  47.313573 seconds (103.05 M allocations: 5.359 GiB, 5.16% gc time, 91.07% compilation time)

After:

  • CPU:
(F)  2.725777 seconds (9.29 M allocations: 478.439 MiB, 6.42% gc time, 99.94% compilation time)
(FB) 285.336 μs (231 allocations: 1.86 MiB)
(B)  32.017971 seconds (64.42 M allocations: 3.362 GiB, 3.73% gc time, 99.97% compilation time)
  • GPU:
(F)  26.145839 seconds (57.83 M allocations: 3.002 GiB, 4.76% gc time, 44.62% compilation time)
(FB) 129.965 μs (555 allocations: 29.50 KiB)
(B)  44.136454 seconds (98.40 M allocations: 5.114 GiB, 5.47% gc time, 90.51% compilation time)

Results from @code_warntype

Before
julia> @code_warntype m(x)
MethodInstance for (::Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}})(::Array{Float32, 4})
  from (c::Chain)(x) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/basic.jl:49
Arguments
  c::Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}
  x::Array{Float32, 4}
Body::Any
1 ─ %1 = Base.getproperty(c, :layers)::Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}
│   %2 = Flux.Tuple(%1)::Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}
│   %3 = Flux.applychain(%2, x)::Any
└──      return %3
After
julia> @code_warntype m(x)
MethodInstance for (::Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}})(::Array{Float32, 4})
  from (c::Chain)(x) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/basic.jl:49
Arguments
  c::Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}
  x::Array{Float32, 4}
Body::Matrix{Float32}
1 ─ %1 = Base.getproperty(c, :layers)::Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}
│   %2 = Flux.Tuple(%1)::Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Flux.var"#312#314", Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}
│   %3 = Flux.applychain(%2, x)::Matrix{Float32}
└──      return %3

@DhairyaLGandhi (Member) commented Jan 8, 2022

This looks really good! I believe you'll need to use BenchmarkTools for the F and B numbers too. Which Julia version was the Any inferred with?

@DhairyaLGandhi (Member) commented:

We probably don't want to have small Unions be part of the dispatch.

@DhairyaLGandhi (Member) left a review comment:

Could we keep the syntax matching the current one? It's harder to spot the most important changes when there are unrelated changes in a PR; the more to the point the changes are, the better.

Review threads on src/conv.jl (outdated, resolved)
@DhairyaLGandhi (Member) commented:

Re note: I believe in most cases that means that the inference is actually going fine.

@pxl-th (Member, Author) commented Jan 8, 2022

Which Julia version was the Any inferred with?

I'm using 1.7.1, but the issue is present on previous versions as well.

In the PR description I've posted @code_warntype output for the whole model, but here it is for a single convolutional layer as well.

Before:

Here you can see that DenseConvDims is not fully inferred, with only the number of spatial dims in its type.

@code_warntype
julia> c = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3));
julia> x = zeros(Float32, 28, 28, 1, 5);
julia> @code_warntype c(x)
MethodInstance for (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(::Array{Float32, 4})
  from (c::Conv)(x::AbstractArray) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/conv.jl:163
Arguments
  c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}
  x::Array{Float32, 4}
Locals
  #198::Flux.var"#198#199"
  cdims::DenseConvDims{2}
  b::Array{Float32, 4}
  σ::typeof(identity)
Body::Any
1 ─ %1  = Base.getproperty(c, :σ)::Core.Const(identity)
│   %2  = Base.getproperty(c, :bias)::Vector{Float32}
│   %3  = Core.tuple(%2)::Tuple{Vector{Float32}}
│         (#198 = %new(Flux.:(var"#198#199")))
│   %5  = #198::Core.Const(Flux.var"#198#199"())
│   %6  = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│   %7  = Flux.length(%6)::Core.Const(2)
│   %8  = Flux.ntuple(%5, %7)::Core.Const((1, 1))
│   %9  = Core.tuple(Flux.:(:), 1)::Core.Const((Colon(), 1))
│   %10 = Core._apply_iterate(Base.iterate, Flux.reshape, %3, %8, %9)::Array{Float32, 4}
│         (σ = %1)
│         (b = %10)
│   %13 = (:stride, :padding, :dilation, :groups)::Core.Const((:stride, :padding, :dilation, :groups))
│   %14 = Core.apply_type(Core.NamedTuple, %13)::Core.Const(NamedTuple{(:stride, :padding, :dilation, :groups)})
│   %15 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│   %16 = Base.getproperty(c, :pad)::NTuple{4, Int64}
│   %17 = Base.getproperty(c, :dilation)::Tuple{Int64, Int64}
│   %18 = Base.getproperty(c, :groups)::Int64
│   %19 = Core.tuple(%15, %16, %17, %18)::Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}
│   %20 = (%14)(%19)::NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}
│   %21 = Core.kwfunc(Flux.DenseConvDims)::Core.Const(Core.var"#Type##kw"())
│   %22 = Base.getproperty(c, :weight)::Array{Float32, 4}
│         (cdims = (%21)(%20, Flux.DenseConvDims, x, %22))
│   %24 = σ::Core.Const(identity)
│   %25 = Base.getproperty(c, :weight)::Array{Float32, 4}
│   %26 = Flux.conv(x, %25, cdims)::AbstractArray{yT, 4} where yT
│   %27 = Base.broadcasted(Flux.:+, %26, b)::Any
│   %28 = Base.broadcasted(%24, %27)::Any
│   %29 = Base.materialize(%28)::Any
└──       return %29

After:

@code_warntype
julia> c = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3));
julia> x = zeros(Float32, 28, 28, 1, 5);
julia> @code_warntype c(x)
MethodInstance for (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(::Array{Float32, 4})
  from (c::Conv)(x::AbstractArray) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/conv.jl:163
Arguments
  c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}
  x::Array{Float32, 4}
Locals
  #198::Flux.var"#198#199"
  cdims::DenseConvDims{2, 2, 2, 4, 2}
  b::Array{Float32, 4}
  σ::typeof(identity)
Body::Array{Float32, 4}
1 ─ %1  = Base.getproperty(c, :σ)::Core.Const(identity)
│   %2  = Base.getproperty(c, :bias)::Vector{Float32}
│   %3  = Core.tuple(%2)::Tuple{Vector{Float32}}
│         (#198 = %new(Flux.:(var"#198#199")))
│   %5  = #198::Core.Const(Flux.var"#198#199"())
│   %6  = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│   %7  = Flux.length(%6)::Core.Const(2)
│   %8  = Flux.ntuple(%5, %7)::Core.Const((1, 1))
│   %9  = Core.tuple(Flux.:(:), 1)::Core.Const((Colon(), 1))
│   %10 = Core._apply_iterate(Base.iterate, Flux.reshape, %3, %8, %9)::Array{Float32, 4}
│         (σ = %1)
│         (b = %10)
│   %13 = (:stride, :padding, :dilation, :groups)::Core.Const((:stride, :padding, :dilation, :groups))
│   %14 = Core.apply_type(Core.NamedTuple, %13)::Core.Const(NamedTuple{(:stride, :padding, :dilation, :groups)})
│   %15 = Base.getproperty(c, :stride)::Tuple{Int64, Int64}
│   %16 = Base.getproperty(c, :pad)::NTuple{4, Int64}
│   %17 = Base.getproperty(c, :dilation)::Tuple{Int64, Int64}
│   %18 = Base.getproperty(c, :groups)::Int64
│   %19 = Core.tuple(%15, %16, %17, %18)::Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}
│   %20 = (%14)(%19)::NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}
│   %21 = Core.kwfunc(Flux.DenseConvDims)::Core.Const(Core.var"#Type##kw"())
│   %22 = Base.getproperty(c, :weight)::Array{Float32, 4}
│         (cdims = (%21)(%20, Flux.DenseConvDims, x, %22))
│   %24 = σ::Core.Const(identity)
│   %25 = Base.getproperty(c, :weight)::Array{Float32, 4}
│   %26 = Flux.conv(x, %25, cdims)::Array{Float32, 4}
│   %27 = Base.broadcasted(Flux.:+, %26, b)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(+), Tuple{Array{Float32, 4}, Array{Float32, 4}}}
│   %28 = Base.broadcasted(%24, %27)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(identity), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4}, Nothing, typeof(+), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}}
│   %29 = Base.materialize(%28)::Array{Float32, 4}
└──       return %29

@pxl-th (Member, Author) commented Jan 8, 2022

I believe you'll need to use BenchmarkTools for the F and B numbers too.

I've used @time to get the measurements for the first run, which includes compilation. I believe @btime runs several times and does not include compilation times in its results (correct me if I'm wrong).
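
For reference, a minimal illustration of the difference between the two:

using BenchmarkTools

f(x) = sum(abs2, x)
x = rand(Float32, 10^6)

@time f(x)    # first call: includes compilation of f for Vector{Float32}
@time f(x)    # second call: runtime only
@btime f($x)  # many samples, reports the minimum, excludes compilation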

@DhairyaLGandhi (Member) commented:

That's right, it would be helpful to see runtime differences.

@pxl-th (Member, Author) commented Jan 8, 2022

Runtime differences

The timings above already include runtime differences (FB), but here are the numbers for the backward pass as well, using @btime on the code from the PR description.

Before:

CPU

  • Forward: 313.752 μs (387 allocations: 1.86 MiB)
  • Backward: 2.211 ms (2453 allocations: 5.81 MiB)

GPU

  • Forward: 161.414 μs (621 allocations: 33.09 KiB)
  • Backward: 1.278 ms (3676 allocations: 459.58 KiB)

After:

CPU

  • Forward: 282.443 μs (231 allocations: 1.86 MiB)
  • Backward: 2.285 ms (2015 allocations: 5.79 MiB)

GPU

  • Forward: 126.859 μs (555 allocations: 29.50 KiB)
  • Backward: 1.202 ms (3518 allocations: 447.41 KiB)

@pxl-th (Member, Author) commented Jan 8, 2022

We probably don't want to have small Unions be part of the dispatch.

Maybe it is OK for now, since runtime is not affected. In future PRs this can be removed, once the other ConvDims subtypes are fixed.

Review threads on src/dim_helpers/DenseConvDims.jl (resolved)
@ToucheSir (Member) commented:

Since this is a breaking change, we can also tolerate the union for now, as there will be follow-up PRs for the other subtypes of ConvDims. We just need to make sure the union is gone before cutting a release :)

@pxl-th changed the title from "Fix conv type-stability" to "Fix convolution & pooling type-stability" on Jan 8, 2022
@pxl-th (Member, Author) commented Jan 8, 2022

Fixed type inference for pooling and depthwise convolution as well.
It might be a lot for one PR, but this way we don't have to worry about multiple PRs and Unions.

Side note: I wanted to make a ResNet model fully inferrable, but it turns out BatchNorm is not inferrable either:

@code_warntype
julia> b = BatchNorm(1);
julia> x = zeros(Float32, 28, 28, 1, 1);
julia> @code_warntype b(x)
MethodInstance for (::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}})(::Array{Float32, 4})
  from (BN::BatchNorm)(x) in Flux at /home/pxl-th/.julia/dev/Flux/src/layers/normalise.jl:269
Arguments
  BN::BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}
  x::Array{Float32, 4}
Locals
  #284::Flux.var"#284#285"{Array{Float32, 4}, Int64}
  affine_shape::NTuple{4, Int64}
  reduce_dims::Vector{Int64}
  N::Int64
Body::Any
1 ─       Core.NewvarNode(:(#284))
│         Core.NewvarNode(:(affine_shape))
│         Core.NewvarNode(:(reduce_dims))
│         Core.NewvarNode(:(N))
│   %5  = Flux.ndims(x)::Core.Const(4)
│   %6  = (%5 - 1)::Core.Const(3)
│   %7  = Flux.size(x, %6)::Int64
│   %8  = Base.getproperty(BN, :chs)::Int64
│   %9  = (%7 == %8)::Bool
└──       goto #3 if not %9
2 ─       goto #4
3 ─ %12 = Base.AssertionError("size(x, ndims(x) - 1) == BN.chs")::Any
└──       Base.throw(%12)
4 ┄       (N = Flux.ndims(x))
│   %15 = (N::Core.Const(4) - 2)::Core.Const(2)
│   %16 = (1:%15)::Core.Const(1:2)
│         (reduce_dims = Base.vcat(%16, N::Core.Const(4)))
│   %18 = Flux.:(var"#284#285")::Core.Const(Flux.var"#284#285")
│   %19 = Core.typeof(x)::Core.Const(Array{Float32, 4})
│   %20 = Core.typeof(N::Core.Const(4))::Core.Const(Int64)
│   %21 = Core.apply_type(%18, %19, %20)::Core.Const(Flux.var"#284#285"{Array{Float32, 4}, Int64})
│         (#284 = %new(%21, x, N::Core.Const(4)))
│   %23 = #284::Core.PartialStruct(Flux.var"#284#285"{Array{Float32, 4}, Int64}, Any[Array{Float32, 4}, Core.Const(4)])
│         (affine_shape = Flux.ntuple(%23, N::Core.Const(4)))
│   %25 = (:reduce_dims, :affine_shape)::Core.Const((:reduce_dims, :affine_shape))
│   %26 = Core.apply_type(Core.NamedTuple, %25)::Core.Const(NamedTuple{(:reduce_dims, :affine_shape)})
│   %27 = Core.tuple(reduce_dims, affine_shape::Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)]))::Core.PartialStruct(Tuple{Vector{Int64}, NTuple{4, Int64}}, Any[Vector{Int64}, Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)])])
│   %28 = (%26)(%27)::Core.PartialStruct(NamedTuple{(:reduce_dims, :affine_shape), Tuple{Vector{Int64}, NTuple{4, Int64}}}, Any[Vector{Int64}, Core.PartialStruct(NTuple{4, Int64}, Any[Core.Const(1), Core.Const(1), Int64, Core.Const(1)])])
│   %29 = Core.kwfunc(Flux._norm_layer_forward)::Core.Const(Flux.var"#_norm_layer_forward##kw"())
│   %30 = (%29)(%28, Flux._norm_layer_forward, BN, x)::Any
└──       return %30
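
A guess at the culprit, judging from the output above: reduce_dims is built with vcat(1:N - 2, N), which yields a Vector{Int}, and the subsequent keyword call to _norm_layer_forward is then inferred as Any. A tuple-based construction would keep the reduction dims in the type domain; a rough sketch of the idea (not the actual Flux code):

# Sketch only, not Flux's implementation.
# Vector-valued dims: the length is a runtime property, so downstream code is hard to infer.
reduce_dims_vec(N::Int) = vcat(1:N - 2, N)   # ::Vector{Int}

# Tuple-valued dims: the length is part of the type, so callers stay type-stable
# (relies on N being known from the array's rank).
reduce_dims_tuple(::AbstractArray{T, N}) where {T, N} = (ntuple(identity, N - 2)..., N)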

@pxl-th (Member, Author) commented Jan 9, 2022

Here's another example with more layers, including pooling:

Code (same imports and device setup as the benchmark in the PR description)
x = device(zeros(Float32, 224, 224, 1, 5))
y = device(zeros(Float32, 10, 5))
m = device(Chain(
    Conv((3, 3), 1 => 2, relu),
    MaxPool((2, 2)),
    Conv((3, 3), 2 => 2, relu),
    MaxPool((2, 2)),
    Conv((3, 3), 2 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    MaxPool((2, 2)),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    Conv((3, 3), 3 => 3, relu),
    MaxPool((2, 2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(147, 10), softmax))
θ = params(m)

@time m(x)
BenchmarkTools.@btime $m($x)

@time gradient(θ) do
    Flux.crossentropy(m(x), y)
end
BenchmarkTools.@btime gradient($θ) do
    Flux.crossentropy($m($x), $y)
end

Timings are listed in the order the calls appear in the code above.

Before:

  • CPU:
8.336650 seconds (33.78 M allocations: 1.800 GiB, 9.91% gc time, 99.63% compilation time)
4.162 ms (803 allocations: 36.60 MiB)
59.726230 seconds (71.67 M allocations: 3.815 GiB, 2.51% gc time, 99.84% compilation time)
33.088 ms (4586 allocations: 107.94 MiB)
  • GPU:
26.857145 seconds (68.45 M allocations: 3.557 GiB, 5.80% gc time, 56.44% compilation time)
309.557 μs (1225 allocations: 70.75 KiB)
67.174837 seconds (104.17 M allocations: 5.416 GiB, 3.66% gc time, 93.56% compilation time)
2.277 ms (6496 allocations: 1.11 MiB)

After:

  • CPU:
2.510704 seconds (9.98 M allocations: 550.919 MiB, 7.70% gc time, 99.62% compilation time)
5.438 ms (446 allocations: 36.58 MiB)
42.624773 seconds (65.89 M allocations: 3.536 GiB, 3.14% gc time, 99.80% compilation time)
36.277 ms (3650 allocations: 107.93 MiB)
  • GPU:
22.459811 seconds (56.46 M allocations: 2.928 GiB, 6.11% gc time, 51.20% compilation time)
263.607 μs (1066 allocations: 56.62 KiB)
59.634783 seconds (99.78 M allocations: 5.186 GiB, 4.05% gc time, 92.60% compilation time)
2.193 ms (6125 allocations: 1.11 MiB)

@ToucheSir (Member) left a review comment:

RE @inferred, I noticed the tests pass now so is that no longer an issue? It's likely you were running into JuliaLang/julia#23749, so I'd recommend Cthulhu.jl as the most reliable way to debug inference.

RE BatchNorm, the revised implementation in FluxML/Flux.jl#1509 may be a starting point (if not a complete solution) for fixing inference. Another approach could be to first create a type-stable functional version in NNlib (simultaneously addressing #19) and then retrofit the Flux layer on top of that.

Review threads on src/dim_helpers/ConvDims.jl and src/dim_helpers/DenseConvDims.jl (outdated, resolved)
@pxl-th (Member, Author) commented Jan 9, 2022

RE @inferred, I noticed the tests pass now so is that no longer an issue? It's likely you were running into JuliaLang/julia#23749, so I'd recommend Cthulhu.jl as the most reliable way to debug inference.

Strangely, it is working now... I've added more inference tests.

@ToucheSir (Member) commented Jan 9, 2022

OK, I think it's actually pooling, not conv, that has regressed performance-wise:

# master
typeof(pdims) = PoolDims{2, (2, 2), (2, 2), (0, 0, 0, 0), (1, 1)}
  91.530 μs (13 allocations: 245.59 KiB)

# PR
typeof(pdims) = PoolDims{2, 2, 2, 4, 2}
  732.686 μs (7 allocations: 245.36 KiB)

A comparison of profiles between the two showed https://github.com/FluxML/NNlib.jl/blob/v0.7.31/src/impl/pooling_direct.jl#L53 appearing post-PR but not on master. My suspicion is that the loop gets unrolled on master, where the kernel size is a type parameter, but not after this PR, as demonstrated by this simple example:

function naive_pool(x::Vector{Float32}, ks::Int)
  @inbounds for x_i in 1:length(x) - ks
    m = x[x_i]
    for k_i in 1:ks
      m = max(m, x[x_i + k_i])
    end
    x[x_i] = m
  end
end

function naive_pool_val(x::Vector{Float32}, ks::Val{K}) where K
  ks = K
  for x_i in 1:length(x) - ks
    m = x[x_i]
    for k_i in 1:ks
      m = max(m, x[x_i + k_i])
    end
    x[x_i] = m
  end
end

x = rand(Float32, 1024)
@code_llvm naive_pool(x, 3) # not unrolled
@code_llvm naive_pool_val(x, Val(3)) # unrolled, multiple references to max

Since conv_direct! and depthwiseconv_direct! use similar inner loops, they may also be affected. WDYT about (re-)promoting the kernel shape and potentially input channels to type space? I wonder if doing that at the last possible moment (i.e. right before calling the impl function) would work (presumably not?)...
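
To make the "promote at the last possible moment" idea concrete, here is a minimal sketch of what it could look like (hypothetical names, not code from this PR):

# The dims struct keeps the kernel size as a plain NTuple field (a runtime value).
struct PoolDimsSketch{N}
    kernel_size::NTuple{N, Int}
end

# One dynamic dispatch at the boundary lifts the kernel size back into the type domain...
maxpool_entry!(y, x, pdims::PoolDimsSketch) = maxpool_kernel!(y, x, Val(pdims.kernel_size))

# ...so inside this method K is a compile-time constant again and the inner loops over
# the pooling window can be unrolled, as they are on master. (Loop body elided.)
function maxpool_kernel!(y, x, ::Val{K}) where {K}
    return y
end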

@ToucheSir (Member) commented:

I managed to shave off ~150μs by using CartesianIndices instead of the multi-range iterator, but that's still a good 6x slower than master. The conundrum is that Flux doesn't keep track of kernel size at the type level. If possible, I'd like to solicit @chriselrod's opinion on whether it's possible to bridge the performance gap while keeping type inference happy.

@pxl-th (Member, Author) commented Jan 11, 2022

I think we should store it at model definition in order to make DenseConvDims type stable, without giving up runtime performance.

Just to confirm that I understand correctly: the parameters would be in the dims type as they were before, but they would be passed in from, for example, MaxPool, instead of being calculated in the dims constructor?

Ok, the dims constructors will definitely need a rework then.

stride, padding, dilation = check_spdf(x_size, (k..., 1, 1), stride, padding, dilation)

in particular is an absolute black hole for type instability, and I haven't yet figured out how to make it behave...

In this PR I did fix the type instability of the PoolDims constructor; maybe some of that can be reused in the rework.
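
One way to keep such a constructor inferrable is to normalize each keyword with a small helper whose return type depends only on the spatial rank, rather than one catch-all check function. A rough sketch (illustrative only, not the code in this PR):

# Illustrative sketch, not the actual NNlib constructor.
# expand turns either a scalar or an N-tuple into an NTuple{N, Int}; the result type
# depends only on N, which is known from the input array's rank.
expand(::Val{N}, x::Integer) where {N} = ntuple(_ -> Int(x), Val(N))
expand(::Val{N}, x::NTuple{N, Integer}) where {N} = map(Int, x)

function pooldims_sketch(x::AbstractArray{T, M}, k; stride = k, padding = 0, dilation = 1) where {T, M}
    N = M - 2   # spatial rank; relies on constant propagation of M
    (kernel   = expand(Val(N), k),
     stride   = expand(Val(N), stride),
     padding  = expand(Val(N), padding),   # symmetric padding per spatial dim, simplified
     dilation = expand(Val(N), dilation))
end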

@DhairyaLGandhi (Member) commented:

Pinging @staticfloat for his eyes on the type stability as well.

@pxl-th (Member, Author) commented Jan 12, 2022

With the latest changes I've managed to nearly match master's performance.
Essentially I'm doing what @ToucheSir suggested earlier: re-promoting kernel_size to the type domain right before calling the kernel.

Pooling benchmarking:

  • Master:
62.949 μs (13 allocations: 128.97 KiB)
39.504 μs (13 allocations: 128.97 KiB)
1.060 ms (17 allocations: 1.15 MiB)
133.712 μs (17 allocations: 1.15 MiB)
  • Latest commit:
77.597 μs (9 allocations: 128.92 KiB)
74.611 μs (9 allocations: 128.92 KiB)
796.221 μs (13 allocations: 1.15 MiB)
124.385 μs (13 allocations: 1.15 MiB)

Model:

First evaluation is slightly slower than it was, but still much faster than on master.
I think this is because more optimizations are applied to the function... I'm not sure :)
Alternatively, it could be because the kernel is recompiled for each distinct kernel_size; that is not the case for the given model, but it can still increase compilation times.

  • Master:
8.336650 seconds (33.78 M allocations: 1.800 GiB, 9.91% gc time, 99.63% compilation time)
4.162 ms (803 allocations: 36.60 MiB)
59.726230 seconds (71.67 M allocations: 3.815 GiB, 2.51% gc time, 99.84% compilation time)
33.088 ms (4586 allocations: 107.94 MiB)
  • Latest commit:
3.406770 seconds (10.15 M allocations: 560.621 MiB, 7.04% gc time, 99.77% compilation time)
3.839 ms (453 allocations: 36.58 MiB)
58.345734 seconds (63.01 M allocations: 3.400 GiB, 2.49% gc time, 99.91% compilation time)
32.999 ms (3665 allocations: 107.93 MiB)

Review thread on src/impl/conv_im2col.jl (outdated, resolved)
@ToucheSir (Member) commented:

Stupendous, thank you @pxl-th for taking this on! I have a couple of tiny non-behavioural suggestions, but other than that I think this is ready to go in :)

        elseif $(name == :mean)
            m += x[input_kw, input_kh, input_kd, c, batch_idx]
        else
            error("Unimplemented codegen path")
        end
    end
-   y[w, h, d, c, batch_idx] = alpha * m + beta * y[w, h, d, c, batch_idx]
+   y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
A review comment from a Contributor:

Why can you eliminate the + beta * y term here?

Review thread on src/impl/conv_direct.jl (outdated, resolved)
@pxl-th (Member, Author) commented Jan 12, 2022

I went ahead and added back the other parameters as Val{...} as well, and it now matches or surpasses the performance on master.
Backward passes for pooling didn't see any improvement from this, so I've left them as is.

pooling_direct

  • Master:
62.577 μs (13 allocations: 128.97 KiB)
39.254 μs (13 allocations: 128.97 KiB)
1.051 ms (17 allocations: 1.15 MiB)
123.622 μs (17 allocations: 1.15 MiB)
  • This PR:
40.165 μs (13 allocations: 129.06 KiB)
40.146 μs (13 allocations: 129.06 KiB)
825.402 μs (13 allocations: 1.15 MiB)
134.233 μs (13 allocations: 1.15 MiB)

conv_direct

  • Master:
3.014 ms (15 allocations: 1.13 MiB)
10.432 ms (29 allocations: 3.43 MiB)
2.993 ms (26 allocations: 1.15 MiB)
  • This PR:
2.760 ms (15 allocations: 1.13 MiB)
10.396 ms (23 allocations: 3.43 MiB)
2.881 ms (19 allocations: 1.15 MiB)
Code (assumes using BenchmarkTools and NNlib; conv_direct and the ∇*_direct functions are NNlib internals and may need to be qualified, e.g. NNlib.conv_direct):
function main()
    x = rand(Float32, 224, 224, 3, 2)
    w = rand(Float32, 3, 3, 3, 3)
    pdims = PoolDims(x, 3)
    cdims = DenseConvDims(x, w)

    y_max = maxpool(x, pdims)
    y_mean = meanpool(x, pdims)
    dy = ones(Float32, size(y_max)...)

    @btime maxpool($x, $pdims)
    @btime meanpool($x, $pdims)

    @btime ∇maxpool($dy, $y_max, $x, $pdims)
    @btime ∇meanpool($dy, $y_mean, $x, $pdims)

    dy_conv = conv_direct(x, w, cdims)

    @btime conv_direct($x, $w, $cdims)
    @btime ∇conv_filter_direct($x, $dy_conv, $cdims)
    @btime ∇conv_data_direct($dy_conv, $w, $cdims)

    nothing
end

First evaluation is slightly slower than it was, but still much faster than on master.

Could've been just noise, because right now I'm seeing timings around 2.7-2.9 seconds.

Review threads on src/dim_helpers/ConvDims.jl, src/conv.jl, and src/impl/pooling_direct.jl (resolved)
@CarloLucibello (Member) commented:

@ToucheSir do you have any further comments? Otherwise this looks good to go.

@mcabbott (Member) left a review comment:

I haven't read all of this closely, but the numbers look good. Two questions:

Review threads on src/dim_helpers/ConvDims.jl (outdated, resolved)
@pxl-th (Member, Author) commented Jan 21, 2022

Pinging to see if there are still things that need to be addressed

@darsnack (Member) commented:

I haven't followed this thread closely. Are we waiting to merge because it is breaking?

@ToucheSir (Member) commented:

Are there any PRs we want to flush out of the queue before a breaking release? If not, I vote to merge.

@pxl-th (Member, Author) commented Jan 23, 2022

I haven't followed this thread closely. Are we waiting to merge because it is breaking?

It is breaking in the sense that the ConvDims type parameters no longer mean the same thing.
Method-wise, however, it is non-breaking.
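
To illustrate the kind of downstream code the parameter change can break (hypothetical user code; the accessor name assumes NNlib's kernel_size helper):

using NNlib

# On master the parameters carried tuples, e.g. PoolDims{2, (2, 2), (2, 2), (0, 0, 0, 0), (1, 1)},
# so downstream code could read the kernel size straight from the type:
old_kernel(::PoolDims{N, K}) where {N, K} = K   # K was a tuple; after this PR it is a plain Int

# Such code should switch to the accessor instead, which keeps working across the change:
new_kernel(pdims::PoolDims) = NNlib.kernel_size(pdims)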

@DhairyaLGandhi (Member) commented:

Let's merge this, and release after updates to the various parts of the ecosystem that specialise on the type params to avoid breakages.

@CarloLucibello merged commit 8d3e059 into FluxML:master on Jan 23, 2022

@pxl-th (Member, Author) commented Jan 24, 2022

Seems like there are unintentional breakages: https://buildkite.com/julialang/nnlibcuda-dot-jl/builds/149#b44502db-331d-457c-898e-db25d0405df4

I see it is now fixed, thanks @DhairyaLGandhi!

Successfully merging this pull request may close: Flux.Conv type instability (FluxML/Flux.jl#1178)
8 participants