diff --git a/.travis.yml b/.travis.yml index e8da6fc66..f6b30fb91 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,6 +9,9 @@ notifications: email: false git: depth: 99999999 +env: + # Disable test fuzzing for the moment, as we're a little too slow for Travis + - NNLIB_TEST_FUZZING=false # Submit to Codecov after_success: diff --git a/Manifest.toml b/Manifest.toml index 276790ac3..4aed2cfd2 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,31 +1,22 @@ +# This file is machine-generated - editing it directly is not advised + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "ff2595695fc4f14427358ce2593f867085c45dcb" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "1.2.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +[[Crayons]] +deps = ["Test"] +git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "1.0.0" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[LibGit2]] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -36,31 +27,14 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MacroTools]] -deps = ["Compat"] -git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.4.4" - [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - [[Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - [[Random]] deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -71,16 +45,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "0.5.2" -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -96,9 +63,11 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[UUIDs]] -deps = ["Random"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +[[TimerOutputs]] +deps = ["Crayons", "Printf", "Test", "Unicode"] +git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.0" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/Project.toml 
b/Project.toml index 4ee32fbad..05a1f7b8d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [deps] Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" diff --git a/REQUIRE b/REQUIRE index 38f15676a..e4c1734f1 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,3 +1,2 @@ -julia 0.7- +julia 1.0 Requires -MacroTools diff --git a/src/NNlib.jl b/src/NNlib.jl index da0d24025..26980d972 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,16 +1,27 @@ module NNlib +using Requires, TimerOutputs -using Requires, Libdl - -export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid, - softmax, logsoftmax, maxpool, meanpool - -include("numeric.jl") +# Include APIs +include("dim_helpers.jl") include("activation.jl") include("softmax.jl") -include("logsoftmax.jl") -include("linalg.jl") +include("gemm.jl") include("conv.jl") -include("cubroadcast.jl") +include("pooling.jl") + +## Include implementations +include("impl/padding_edges.jl") + +# Direct implementations of convolutional and depthwise-convolutional algorithms +include("impl/conv_direct.jl") +include("impl/depthwiseconv_direct.jl") +# im2col implementations of convolutional and depthwise-convolutional algorithms +include("impl/conv_im2col.jl") +include("impl/depthwiseconv_im2col.jl") + +# Direct implementations of pooling +include("impl/pooling_direct.jl") + +to = TimerOutput() -end # module +end # module NNlib \ No newline at end of file diff --git a/src/activation.jl b/src/activation.jl index 19d5cda80..371db9019 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -1,3 +1,6 @@ +export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, + logsigmoid + """ σ(x) = 1 / (1 + exp(-x)) @@ -5,18 +8,16 @@ Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation function. """ σ(x) = one(x) / (one(x) + exp(-x)) - const sigmoid = σ # ForwardDiff numerical stability hack σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x))) - σ(x::Float32) = σ_stable(x) - @init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x) end + """ logσ(x) @@ -31,13 +32,13 @@ Return `log(σ(x))` which is computed in a numerically stable way. -0.0 """ function logσ(x) - max_v = max(zero(x), -x) - z = exp(-max_v) + exp(-x-max_v) - -(max_v + log(z)) + max_v = max(zero(x), -x) + z = exp(-max_v) + exp(-x-max_v) + return -(max_v + log(z)) end - const logsigmoid = logσ + """ relu(x) = max(0, x) @@ -56,6 +57,7 @@ You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`. """ leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1) + """ elu(x, α = 1) = x > 0 ? x : α * (exp(x) - 1) @@ -66,6 +68,7 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`. 
""" elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x))) + """ gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3))) @@ -103,6 +106,7 @@ function selu(x) λ * ifelse(x > 0, x/1, α * (exp(x) - 1)) end + """ softsign(x) = x / (1 + |x|) diff --git a/src/conv.jl b/src/conv.jl index 5f2cd848d..ada11ee35 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -1,266 +1,153 @@ -Dims{N} = NTuple{N,Integer} - -include("impl/pool.jl") -include("impl/conv.jl") - -# Convolutions - -function cdims(x::NTuple{N}, w::NTuple{N}, pad, stride) where N - ntuple(Val(N)) do i - if i < N-1 - 1 + div(x[i] - w[i] + 2*pad[i], stride[i]) - elseif i == N-1 - w[N] - else # i == N - x[N] +export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter! + +## Convolution API +# +# We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d, +# 2d and 3d convolutions, based on the rank of the input tensors, in both mutating and +# non-mutating auto-allocating variants: +# - Convolution: +# - conv(x, w, cdims) +# - conv!(y, x, w, cdims) +# - Convolution data backpropagation +# - ∇conv_data(dy, w, cdims) +# - ∇conv_data!(dx, dy, w, cdims) +# - Convolution filter backpropagation +# - ∇conv_filter(x, dy, cdims) +# - ∇conv_filter!(dw, x, dy, cdims) +# +# All methods require a `ConvDims` object to define the dimensions and optional +# elements of the convolution (padding, stride, dilation, kernel-flipping, etc...), +# which is easily constructable through something like `DenseConvDims(x, w)`. All +# methods take in the `ConvDims` of the associated normal, forward-pass convolution, +# that is, the following is legal: +# +# cdims = ConvDims(x, w; stride=2, dilation=(3,2)) +# dx = ∇conv_data(conv(x, w, cdims), w, cdims) + + + +# First, we will define mappings from the generic API names to our accelerated backend +# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using +# im2col + GEMM. Do so in a loop, here: +for (front_name, backend) in ( + # This maps from public, front-facing name, to internal backend name + :conv => :im2col, + :∇conv_data => :im2col, + :∇conv_filter => :im2col, + :depthwiseconv => :im2col, + :∇depthwiseconv_data => :im2col, + :∇depthwiseconv_filter => :im2col, + ) + + # These are the GEMM types we will accelerate with `im2col` + G = Union{[x[2] for x in gemm_datatype_mappings]...} + + # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution + @eval begin + # im2col-accelerated function forwarding definition + @timeit_debug to function $(Symbol("$(front_name)!"))( + out::AbstractArray{T,5}, in1::AbstractArray{T,5}, + in2::AbstractArray{T,5}, cdims::ConvDims; kwargs...) where {T <: $G} + $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...) + end end - end end - -# Conv Transpose dims - -function ctdims(x::NTuple{N}, w::NTuple{N}, pad, stride, dilation) where N - ntuple(Val(N)) do i - if i < N-1 - (x[i] - 1) * stride[i] + dilation[i] * (w[i] - 1) - 2*pad[i] + 1 - elseif i == N-1 - w[N-1] - else # i == N - x[N] +# Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which +# makes things MUCH EASIER for us on the backend side, and is in general pretty fast, +# since we can specialize on sizes. 
+for front_name in (:conv, :∇conv_data, :∇conv_filter, + :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) + for backend in (Symbol(), :_direct, :_im2col) + for N in (3, 4) + @eval begin + function $(Symbol("$(front_name)$(backend)!"))( + y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N}, + w::AbstractArray{wT,$N}, cdims::ConvDims; + kwargs...) where {yT, xT, wT} + $(Symbol("$(front_name)$(backend)!"))( + insert_singleton_spatial_dimension(y, $(5 - N)), + insert_singleton_spatial_dimension(x, $(5 - N)), + insert_singleton_spatial_dimension(w, $(5 - N)), + insert_singleton_spatial_dimension(cdims, $(5 - N)); + kwargs... + ) + + # We explicitly return `y` here, because the backend call + # itself may return a reshaped view, which we don't want. + return y + end + end + end end - end end - -# Kernel dims - -function wdims(x::NTuple{N}, y::NTuple{N}, pad, stride, dilation) where N - ntuple(Val(N)) do i - if i < N-1 - 1 + div((1 - y[i]) * stride[i] + x[i] + 2pad[i] - 1, dilation[i]) - elseif i == N-1 - x[i] - else # i == N - y[i-1] +# We always support a fallback, non-accelerated path, where we use the direct, but +# slow, implementations. These should not typically be used, hence the `@debug`, +# but let's go ahead and define them first: +for front_name in (:conv, :∇conv_data, :∇conv_filter, + :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) + @eval begin + function $(Symbol("$(front_name)!"))( + y::AbstractArray{yT,N}, in1::AbstractArray{T1,N}, + in2::AbstractArray{T2,N}, cdims::ConvDims; + kwargs...) where {yT, T1, T2, N} + @debug string("Slow fallback implementation invoked for $(front_name)! ", + "You probably don't want this; check your datatypes.") + $(Symbol("$(front_name)_direct!"))(y, in1, in2, cdims; kwargs...) + end end - end -end - -# Interface - -head(x) = reverse(Base.tail(reverse(x))) -padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x))) -padtuple(x::Tuple,p::Tuple) = p -padtuple(x::AbstractArray,p) = padtuple(size(x),p) - -function conv(x::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1) - pad_, stride_ = padtuple(x, pad), padtuple(x, stride) - if size === nothing - size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_) - end - conv!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation) -end - -function crosscor(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) where A<:AbstractArray - pad_, stride_ = padtuple(x, pad), padtuple(x, stride) - if size === nothing - size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_) - end - crosscor!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation) -end - -function ∇conv_data(dy::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1, flipkernel = 0) - pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation) - if size === nothing - size = ctdims(Base.size(dy), Base.size(w), pad_, stride_, dilation_) - end - ∇conv_data!(similar(dy, size), dy, w, pad = pad_, stride = stride_, dilation = dilation_, flipkernel=flipkernel) -end - -function ∇conv_filter(dy::AbstractArray, x::AbstractArray; size = nothing, pad = 0, stride = 1, dilation = 1, flipkernel=0) - pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation) - if size === nothing - size = wdims(Base.size(x), Base.size(dy), pad_, stride_, dilation_) - end - ∇conv_filter!(zero(similar(dy, size)), dy, x; pad = pad, stride = stride, dilation = dilation,
flipkernel=flipkernel) end -# N-D dispatch - -function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; - pad = 0, stride = 1, dilation = 1, flipkernel =0) where T - args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w)) - conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel) - return y -end - -function crosscor!(y::AbstractArray, x::AbstractArray, w::AbstractArray; - pad = 0, stride = 1, dilation = 1) - conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1) -end - -function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3}, x::AbstractArray{T,3}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T - args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x)) - ∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel) - return dw -end - -function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3}, w::AbstractArray{T,3}; - pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T - args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, w)) - ∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1), flipkernel = flipkernel) - return dx -end - -conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - -∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv2d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - -∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv2d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - -conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - -∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv3d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - -∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}; - pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = - conv3d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) - - # Depthwise Conv - -function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride) - ((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4]) -end - -function depthwiseconv(x::AbstractArray, w::AbstractArray; pad = 0, stride = 1) - pad_, stride_ = padtuple(x, pad), padtuple(x, stride) - depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_) -end - -function depthwisecrosscor(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray - pad_, stride_ = padtuple(x, pad), padtuple(x, stride) - depthwisecrosscor!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_) -end - -depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 
0, stride = 1, flipkernel=0) where T = - depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel) - -depthwisecrosscor!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1) where T = - depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1) - -∇depthwiseconv_data(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, flipkernel=0) = - ∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel) - -∇depthwiseconv_filter(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, flipkernel=0) = - ∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel) - -∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, flipkernel=0) where T = - depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=flipkernel) +# Finally, let's generate auto-allocating versions of all our functions, for all backends. +# We `@timeit` these methods separately, as we want to know how much time is spent in +# allocation. :P +for backend in (Symbol(), :_direct, :_im2col) + # First make auto-allocating versions of the conv()-like calls: + for name in (:conv, :depthwiseconv) + @eval begin + @timeit_debug to function $(Symbol("$(name)$(backend)"))( + x::AbstractArray{xT,N}, w::AbstractArray{wT,N}, + cdims::ConvDims; kwargs...) where {xT, wT, N} + y = similar(x, promote_type(xT, wT), output_size(cdims)..., + channels_out(cdims), size(x,N)) + return $(Symbol("$(name)$(backend)!"))(y, x, w, cdims; kwargs...) + end + end + end -∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, flipkernel=0) where T = - depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=flipkernel) + for name in (:∇conv_data, :∇depthwiseconv_data) + @eval begin + @timeit_debug to function $(Symbol("$(name)$(backend)"))( + dy::AbstractArray{yT,N}, w::AbstractArray{wT,N}, + cdims::ConvDims; kwargs...) where {yT, wT, N} + dx = similar(dy, input_size(cdims)..., channels_in(cdims), + size(dy, N)) + return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...) + end + end + end -# Pooling + # We do the conv/depthwiseconv filter backprops separately, as the shape calculation + # for `w` is slightly different for depthwise than for normal dense convolution. + @eval begin + @timeit_debug to function $(Symbol("∇conv_filter$(backend)"))( + x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, + cdims::ConvDims; kwargs...) where {xT, yT, N} + dw = similar(dy, kernel_size(cdims)..., channels_in(cdims), + channels_out(cdims)) + return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...) + end + end -function pdims(dims::Dims{N}, window, padding, stride) where N - ntuple(Val(N)) do i - if i < N-1 - 1 + (dims[i] + 2*padding[i] - window[i])÷stride[i] - else - dims[i] + @eval begin + @timeit_debug to function $(Symbol("∇depthwiseconv_filter$(backend)"))( + x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, + cdims::ConvDims; kwargs...) where {xT, yT, N} + dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims), + channels_in(cdims)) + return $(Symbol("∇depthwiseconv_filter$(backend)!"))(dw, x, dy, cdims; + kwargs...) 
+ end end - end end - -expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val(N)) -expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i - -# Interface - -maxpool(x::AbstractArray, k; pad = map(_->0,k), stride = k) = - maxpool!(similar(x, pdims(size(x), k, expand(Val{length(k)}, pad), - expand(Val{length(k)}, stride))), x, k, pad = expand(Val{length(k)}, pad), - stride = expand(Val{length(k)}, stride)) - -maxpool!(y::A, x::A, k; kw...) where A<:AbstractArray = - maxpool_cpu!(y, x, k; kw...) - -∇maxpool(dy::A, y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray = - ∇maxpool!(similar(x), dy, y, x, k, pad = expand(Val{length(k)}, pad), - stride = expand(Val{length(k)}, stride)) - -∇maxpool!(dx::A, dy::A, y::A, x::A, k; kw...) where A<:AbstractArray = - ∇maxpool_cpu!(dx, dy, y, x, k; kw...) - -meanpool(x::AbstractArray, k; pad = map(_->0,k), stride = k) = - meanpool!(similar(x, pdims(size(x), k, expand(Val{length(k)}, pad), - expand(Val{length(k)}, stride))), x, k, pad = expand(Val{length(k)}, pad), - stride = expand(Val{length(k)}, stride)) - -meanpool!(y::A, x::A, k; kw...) where A<:AbstractArray = - meanpool_cpu!(y, x, k; kw...) - -∇meanpool(dy::A, y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray = - ∇meanpool!(similar(x), dy, y, x, k, pad = expand(Val{length(k)}, pad), - stride = expand(Val{length(k)}, stride)) - -∇meanpool!(dx::A, dy::A, y::A, x::A, k; kw...) where A<:AbstractArray = - ∇meanpool_cpu!(dx, dy, y, x, k; kw...) - -# N-D dispatch -# We use a separate function to avoid ambiguity issues -# (more specific array types vs. more specific dimensions) - -maxpool_cpu!(y::A, x::A, k::Dims{2}; pad = (0,0), stride = k) where A<:AbstractArray{<:Real,4} = - maxpool2d!(y, x, window = k, padding = pad, stride = stride) - -∇maxpool_cpu!(dx::AbstractArray{<:Real,4}, dy::AbstractArray{<:Real,4}, y::AbstractArray{<:Real,4}, x::AbstractArray{<:Real,4}, - k::Dims{2}; pad = (0,0), stride = k) = - maxpool2d_grad!(dx, dy, y, x, - window = k, padding = pad, stride = stride) - -maxpool_cpu!(y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, k::Dims{3}; pad = (0,0), stride = k) = - maxpool3d!(y, x, window = k, padding = pad, stride = stride) - -∇maxpool_cpu!(dx::AbstractArray{<:Real,5}, dy::AbstractArray{<:Real,5}, y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, - k::Dims{3}; pad = (0,0), stride = k) = - maxpool3d_grad!(dx, dy, y, x, - window = k, padding = pad, stride = stride) - -meanpool_cpu!(y::AbstractArray{<:Real,4}, x::AbstractArray{<:Real,4}, k::Dims{2}; pad = (0,0), stride = k) = - meanpool2d!(y, x, window = k, padding = pad, stride = stride) - -∇meanpool_cpu!(dx::AbstractArray{<:Real,4}, dy::AbstractArray{<:Real,4}, y::AbstractArray{<:Real,4}, x::AbstractArray{<:Real,4}, - k::Dims{2}; pad = (0,0), stride = k) = - meanpool2d_grad!(dx, dy, y, x, - window = k, padding = pad, stride = stride) - -meanpool_cpu!(y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, k::Dims{3}; pad = (0,0), stride = k) = - meanpool3d!(y, x, window = k, padding = pad, stride = stride) - -∇meanpool_cpu!(dx::AbstractArray{<:Real,5}, dy::AbstractArray{<:Real,5}, y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, - k::Dims{3}; pad = (0,0), stride = k) = - meanpool3d_grad!(dx, dy, y, x, - window = k, padding = pad, stride = stride) - -# Deprecated - -# 0.4.2 -@deprecate ∇conv_data(dy::A, x::A, w::A; kw...) where A<:AbstractArray ∇conv_data(dy, w; size=size(x), kw...) -@deprecate ∇conv_filter(dy::A, x::A, w::A; kw...) 
where A<:AbstractArray ∇conv_filter(dy, x; size=size(w), kw...) diff --git a/src/cubroadcast.jl b/src/cubroadcast.jl deleted file mode 100644 index 34e002a91..000000000 --- a/src/cubroadcast.jl +++ /dev/null @@ -1,2 +0,0 @@ -# Kept only for backwards compatibility -macro fix(ex) esc(ex) end diff --git a/src/dim_helpers.jl b/src/dim_helpers.jl new file mode 100644 index 000000000..d02ec551c --- /dev/null +++ b/src/dim_helpers.jl @@ -0,0 +1,122 @@ +# Various helper functions to calculate dimensions for operations +include("dim_helpers/ConvDims.jl") +include("dim_helpers/DenseConvDims.jl") +include("dim_helpers/DepthwiseConvDims.jl") +include("dim_helpers/PoolDims.jl") + + +""" + transpose_swapbatch(x::AbstractArray) + +Given an AbstractArray, swap its batch and channel axes, as we must during transposed +convolution. We do this to the operands during convolution, and then again to the +output once we're done. +""" +function transpose_swapbatch(x::AbstractArray) + return permutedims(x, ((1:(ndims(x)-2))..., ndims(x), ndims(x)-1)) +end +function transpose_swapbatch(x::Tuple) + return (x[1:end-2]..., x[end], x[end-1]) +end + +""" + transpose_pad(cdims::ConvDims) + +Transposed convolution can be calculated in terms of typical convolution with some extra +padding. This method computes the padding of the convolution that would result in the +transposed convolution of two operands, in essence taking care of that "extra padding". +Note that this method should almost always be accompanied by a call that predilates one +of the operands. +""" +function transpose_pad(cdims::ConvDims) + I = input_size(cdims) + K = kernel_size(cdims) + D = dilation(cdims) + P = padding(cdims) + S = stride(cdims) + return ntuple(length(P)) do i + hi = ceil(Int, i/2) + if mod(i, 2) == 1 + return (K[hi] - 1)*D[hi] - P[i] + else + return (K[hi] - 1)*D[hi] - P[i] + mod(I[hi] + P[i-1] + P[i] - (K[hi] - 1)*D[hi] - 1, S[hi]) + end + end +end + +""" + insert_singleton_spatial_dimension(cdims::DenseConvDims) + +When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton +spatial dimension at the end of the spatial dimensions. This does so for a ConvDims. +""" +@inline function insert_singleton_spatial_dimension(cdims::C) where {C <: ConvDims} + return basetype(C)(cdims; + N=spatial_dims(cdims) + 1, + I=(input_size(cdims)..., 1), + K=(kernel_size(cdims)..., 1), + S=(stride(cdims)..., 1), + # Padding is always the problem child.... + P=(padding(cdims)..., 0, 0), + D=(dilation(cdims)..., 1), + ) +end + +@inline function insert_singleton_spatial_dimension(x::AbstractArray) + return reshape(x, size(x)[1:end-2]..., 1, size(x)[end-1:end]...) +end + +# Helper to do this multiple times +@inline function insert_singleton_spatial_dimension(x, reps::Int) + for r in 1:reps + x = insert_singleton_spatial_dimension(x) + end + return x +end + +""" + predilated_size(x_size::Tuple, dilation::Tuple) + +Calculate the size of a predilated `x` given a particular dilation factor. This is used +within `predilate()` and `transpose_cdims()`. +""" +function predilated_size(x_size::NTuple{N}, dilation::NTuple{M}) where {N, M} + @assert (M == N - 2) DimensionMismatch("len(dilation) != number of spatial dims") + return ntuple(N) do idx + if idx <= N - 2 + return (x_size[idx] - 1)*dilation[idx] + 1 + else + x_size[idx] + end + end +end + +""" + predilate(x, dilation::Tuple) + +Places elements of `x` within a lattice of zeros, used in expressing a transposed +convolution in terms of normal convolution. 
Note that while we call this "predilation" +for aesthetic reasons, you are typically passing a "stride" value into here. Yes, +transposed convolution is confusing. +""" +function predilate(x::AbstractArray{T,N}, dilation::NTuple{M}) where {T, N, M} + @assert (M == N - 2) DimensionMismatch("len(dilation) != number of spatial dims") + + # If there is no dilation to be done, then ignore it. + if all(dilation .== 1) + return x + end + + # Validate dilation factors + for idx in 1:length(dilation) + @assert dilation[idx] >= 1 ArgumentError("dilation cannot be less than 1") + end + + # Create new x that is bigger and holier + x_dil = zeros(eltype(x), predilated_size(size(x), dilation)) + + # Fill in strategic locations within `x_dil`, such that there are `dilation[idx] - 1` + # zeros between each element of `x` along each spatial dimension. + x_dil[(1:dilation[idx]:size(x_dil,idx) for idx in 1:(N-2))..., :, :] .= x + return x_dil +end \ No newline at end of file diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl new file mode 100644 index 000000000..335cf4389 --- /dev/null +++ b/src/dim_helpers/ConvDims.jl @@ -0,0 +1,122 @@ +export ConvDims + +""" + ConvDims + +Type system-level information about convolution dimensions. Critical for things like +`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs +getting passed around. + +We don't want to specialize on things like image size/channel count, so we generally +store those as fields, just for convenience, and to allow for non-breaking changes when +we decide we _do_ want to specialize on those values. We always want to specialize on +things like stride, padding, dilation, and kernel flipping though. +""" +abstract type ConvDims{N, S, P, D, F} end + +# Hack to get rid of type parameters +function basetype(::Type{C}) where {C <: ConvDims} + if C <: DepthwiseConvDims + return DepthwiseConvDims + elseif C <: DenseConvDims + return DenseConvDims + elseif C <: PoolDims + return PoolDims + else + return nothing + end +end + +# Obvious getter definitions for the type system-level definitions +spatial_dims(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = N +stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S +padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P +dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D +flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F + +""" + im2col_dims(c::ConvDims) + +im2col calculates, for each output pixel, the "convolution" of N kernels where N is the +number of output channels, by doing a matrix multiply. The dimensions of that matrix +are given by this function. +""" +im2col_dims(c::ConvDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c)) + +# Protect your skin, kids. Also do common validation of stride, padding, etc... +function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} + # Number of spatial dimensions in `x` and `w`. + nd = N - 2 + + # Given a number, duplicate it out to have `nd` length. If it's already a collection, + # just splat it out into a tuple so it's always a tuple. We'll lint length later. + expand_size(p::Number) = ntuple(_ -> Int(p), nd) + expand_size(p) = tuple(p...) + + # Convert stride, padding, dilation, etc.. 
to fully-specified tuples + pstride = expand_size(stride) + pdilation = expand_size(dilation) + ppadding = expand_size(padding) + + if length(pstride) != nd + throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!")) + end + if length(pdilation) != nd + throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!")) + end + + # padding is kind of a special case; we allow it to be either 2-length or 4-length, + # since we support asymmetrical padding + if length(ppadding) != 2*nd + if length(ppadding) == nd + # Do this repeat dance so that we get lo/hi symmetrical padding + ppadding = tuple(repeat(collect(ppadding), inner=2)...) + else + throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!")) + end + end + + # Assert that kernel size * dilation is <= padded input size + for idx in 1:nd + Is = x_size[idx] + Pl = ppadding[(idx - 1)*2 + 1] + Ph = ppadding[(idx - 1)*2 + 2] + Ks = w_size[idx] + Ds = pdilation[idx] + if Is + Pl + Ph < (Ks - 1)*Ds + 1 + throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!")) + end + end + + return pstride, ppadding, pdilation +end + +""" + output_size(c::ConvDims) + +Calculate the output (spatial) dimensions of the convolution. Get channel count via +`channels_out(c)`, and batch count is unknowable. +""" +function output_size(c::ConvDims) + I = input_size(c) + K = kernel_size(c) + S = stride(c) + P = padding(c) + D = dilation(c) + + return ntuple(spatial_dims(c)) do i + return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1 + end +end + +# Override show() for these beauties +function Base.show(io::IO, cdims::C) where {C <: ConvDims} + I = (input_size(cdims)..., channels_in(cdims)) + O = (output_size(cdims)..., channels_out(cdims)) + K = kernel_size(cdims) + S = stride(cdims) + P = padding(cdims) + D = dilation(cdims) + F = flipkernel(cdims) + print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F") +end diff --git a/src/dim_helpers/DenseConvDims.jl b/src/dim_helpers/DenseConvDims.jl new file mode 100644 index 000000000..df559e6fa --- /dev/null +++ b/src/dim_helpers/DenseConvDims.jl @@ -0,0 +1,77 @@ +export DenseConvDims + +""" + DenseConvDims + +Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d. +""" +struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F} + I::NTuple{N,Int} +end + +# Getters for the fields +input_size(c::DenseConvDims) = c.I +kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = K +channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_in +channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_out + +# Convenience wrapper to create DenseConvDims objects +function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; + stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M + # Do common parameter validation + stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) + + # Ensure channels are equal + if x_size[end-1] != w_size[end-1] + xs = x_size[end-1] + ws = w_size[end-1] + throw(DimensionMismatch("Input channels must match! ($xs vs. 
$ws)")) + end + + # The type parameters are what + return DenseConvDims{ + M - 2, + w_size[1:end-2], + x_size[end-1], + w_size[end], + stride, + padding, + dilation, + flipkernel + }( + # Input spatial size + x_size[1:end-2], + ) +end + +# Auto-extract sizes and sub out to big brother above +function DenseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) + if ndims(x) != ndims(w) + throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) + end + return DenseConvDims(size(x), size(w); kwargs...) +end + +# Useful for constructing a new DenseConvDims that has only a few elements different +# from the original progenitor object that it inherits shapes from. +function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), + C_in=channels_in(c), C_out=channels_out(c), S=stride(c), + P=padding(c), D=dilation(c), F=flipkernel(c)) + return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I) +end + +function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M} + # First, check that channel counts are all correct: + @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))") + @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))") + @assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))") + @assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))") + + # Next, check that the spatial dimensions match up + @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))") + @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))") + @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))") + + # Finally, check that the batch size matches + @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))") +end diff --git a/src/dim_helpers/DepthwiseConvDims.jl b/src/dim_helpers/DepthwiseConvDims.jl new file mode 100644 index 000000000..a0555ff58 --- /dev/null +++ b/src/dim_helpers/DepthwiseConvDims.jl @@ -0,0 +1,90 @@ +export DepthwiseConvDims + +""" + DepthwiseConvDims + +Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to +characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from +DenseConvDims primarily for channel calculation differences. 
+""" +struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F} + I::NTuple{N, Int} + K::NTuple{N, Int} + C_in::Int + C_mult::Int +end + +# Getters for the fields +input_size(c::DepthwiseConvDims) = c.I +kernel_size(c::DepthwiseConvDims) = c.K +channels_in(c::DepthwiseConvDims) = c.C_in +channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c) +channel_multiplier(c::DepthwiseConvDims) = c.C_mult + + +# Convenience wrapper to create DepthwiseConvDims objects +function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; + stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M + # Do common parameter validation + stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) + + # Ensure channels are equal + if x_size[end-1] != w_size[end] + xs = x_size[end-1] + ws = w_size[end] + throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)")) + end + + return DepthwiseConvDims{ + M - 2, + stride, + padding, + dilation, + flipkernel + }( + # Image spatial size + x_size[1:end-2], + + # Kernel spatial size + w_size[1:end-2], + + # Input channels + x_size[end-1], + + # Channel multiplier + w_size[end-1], + ) +end + +# Auto-extract sizes and just pass those directly in +function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) + if ndims(x) != ndims(w) + throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) + end + return DepthwiseConvDims(size(x), size(w); kwargs...) +end + +# Useful for constructing a new DepthwiseConvDims that has only a few elements different +# from the original progenitor object. +function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), + C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c), + P=padding(c), D=dilation(c), F=flipkernel(c)) + return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m) +end + +# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count +function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M} + # First, check that channel counts are all correct: + @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))") + @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))") + @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))") + @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))") + + # Next, check that the spatial dimensions match up + @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))") + @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))") + @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))") + + # Finally, check that the batch size matches + @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. 
$(y[end]))") +end \ No newline at end of file diff --git a/src/dim_helpers/PoolDims.jl b/src/dim_helpers/PoolDims.jl new file mode 100644 index 000000000..97144968e --- /dev/null +++ b/src/dim_helpers/PoolDims.jl @@ -0,0 +1,72 @@ +export PoolDims + +""" + PoolDims + +Dimensions for a "pooling" operation that can have an arbitrary input size, kernel size, +stride, dilation, and channel count. Used to dispatch onto efficient implementations at +compile-time. +""" +struct PoolDims{N,K,S,P,D} <: ConvDims{N, S, P, D, false} + I::NTuple{N,Int} + C_in::Int +end + +# Getters for both type parameters and fields +kernel_size(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = K +input_size(c::PoolDims) = c.I +channels_in(c::PoolDims) = c.C_in +channels_out(c::PoolDims) = c.C_in + + +# Convenience wrapper to create DenseConvDims objects +function PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int}; + stride=k, padding=0, dilation=1) where {M, L} + # Expand `k` up to a tuple + if typeof(k) <: Number + k = ntuple(_ -> k, M - 2) + end + + # Do common parameter validation + stride, padding, dilation = check_spdf(x_size, (k..., 1, 1), stride, padding, dilation) + + # Build it + return PoolDims{ + M - 2, + k, + stride, + padding, + dilation + }( + # Image spatial size + x_size[1:end-2], + + # Input channels + x_size[end-1], + ) +end + +# Auto-take `size(x)` when `x` is an array. +function PoolDims(x::AbstractArray, k; kwargs...) + return PoolDims(size(x), k; kwargs...) +end + +# Useful for constructing a new PoolDims that has only a few elements different +# from the original progenitor object that it inherits shapes from. +function PoolDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), + C_in=channels_in(c), S=stride(c), P=padding(c), D=dilation(c)) + return PoolDims{N, K, S, P, D}(I, C_in) +end + +function check_dims(x::NTuple{M}, y::NTuple{M}, pdims::PoolDims) where {M} + # First, check that channel counts are all correct: + @assert x[end-1] == channels_in(pdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(pdims)))") + @assert y[end-1] == channels_out(pdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(pdims)))") + + # Next, check that the spatial dimensions match up + @assert x[1:end-2] == input_size(pdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(pdims)))") + @assert y[1:end-2] == output_size(pdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(pdims)))") + + # Finally, check that the batch size matches + @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))") +end diff --git a/src/gemm.jl b/src/gemm.jl new file mode 100644 index 000000000..e9d2a8461 --- /dev/null +++ b/src/gemm.jl @@ -0,0 +1,58 @@ +## Low level gemm! call with pointers +## Borrowed from Knet.jl, adapted for compile-time constants + +using LinearAlgebra +using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc + +""" + gemm!() + +Low-level gemm!() call with pointers, borrowed from Knet.jl + +Calculates `C = alpha*op(A)*op(B) + beta*C`, where: + - `transA` and `transB` set `op(X)` to be either `identity()` or `transpose()` + - alpha and beta are scalars + - op(A) is an (M, K) matrix + - op(B) is a (K, N) matrix + - C is an (M, N) matrix. +""" +gemm! 
+ +# These are the datatypes we have fast GEMM for +gemm_datatype_mappings = ( + (:dgemm_, Float64), + (:sgemm_, Float32), + (:zgemm_, ComplexF64), + (:cgemm_, ComplexF32), +) +for (gemm, elt) in gemm_datatype_mappings + @eval begin + @inline function gemm!(transA::Val, transB::Val, + M::Int, N::Int, K::Int, + alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, + beta::$(elt), C::Ptr{$elt}) + # Convert our compile-time transpose marker to a char for BLAS + convtrans(V::Val{false}) = 'N' + convtrans(V::Val{true}) = 'T' + + if transA == Val(false) + lda = M + else + lda = K + end + if transB == Val(false) + ldb = K + else + ldb = N + end + ldc = M + ccall((@blasfunc($(gemm)), libblas), Nothing, + (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, + Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, + Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, + Ref{BlasInt}), + convtrans(transA), convtrans(transB), M, N, K, + alpha, A, lda, B, ldb, beta, C, ldc) + end + end +end diff --git a/src/impl/conv.jl b/src/impl/conv.jl deleted file mode 100644 index 3523199f6..000000000 --- a/src/impl/conv.jl +++ /dev/null @@ -1,556 +0,0 @@ -# convert padding etc. size to an Int array of the right dimension -function psize(p, x) - nd = ndims(x)-2 - if isa(p,Number) - ntuple(_->Int(p), nd) - elseif length(p)==nd - tuple(p...) - else - throw(DimensionMismatch("psize: $p $nd")) - end -end - -# Type system-level information about convolution dimensions. Critical for things like -# im2col_2d!() to generate efficient code. -struct ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel} end -img_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = I - -# Calculate the output dimensions of this convolution -function output_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} - O_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1 - O_h = div(I[2] + P[3] + P[4] - (K[2] - 1) * D[2] - 1, S[2]) + 1 - return (O_w, O_h) -end -kernel_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = K -img_channels(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = C -stride(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = S -padding(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = P -dilation(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = D -flipkernel(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = F - -function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, cdims::ConvDims) where T - width, height = img_size(cdims) - kernel_w, kernel_h = kernel_size(cdims) - channels = img_channels(cdims) - pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi = padding(cdims) - dil_w, dil_h = dilation(cdims) - stride_w, stride_h = stride(cdims) - width_col, height_col = output_size(cdims) - - if flipkernel(cdims) - flipk = (w, h) -> (kernel_w - w + 1, kernel_h - h + 1) - else - flipk = (w, h) -> (w, h) - end - - # Reshape col for easy access. - col_reshaped = reshape(col, (width_col, height_col, kernel_w, kernel_h, channels)) - - # Let us first calculate the number of rows/columns within which we must zero out some - # portion of the image patches we're copying over. Note the subtractions on the `_hi` - # variants are due to us needing to account for padding that is completely ignored due - # to stride/dilation/kernel size combinations. 
- spill_w_lo = ceil(Int, pad_w_lo/stride_w) - spill_w_hi = width_col - div(width + pad_w_lo - (kernel_w - 1)*dil_w, stride_w) - spill_h_lo = ceil(Int, pad_h_lo/stride_h) - spill_h_hi = height_col - div(height + pad_h_lo - (kernel_h - 1)*dil_h, stride_h) - spill_w_hi_abs = width_col - spill_w_hi + 1 - spill_h_hi_abs = height_col - spill_h_hi + 1 - - # First, a helper function to project from output (w, h) to input (input_w, input_h) - project(idx, stride, pad) = (idx - 1)*stride - pad + 1 - - # These are the regions we're going to have to run with cognizance of padding - padded_regions = ( - (1:width_col, 1:spill_h_lo), - (1:spill_w_lo, (spill_h_lo+1):(spill_h_hi_abs-1)), - (spill_w_hi_abs:width_col, (spill_h_lo+1):(spill_h_hi_abs-1)), - (1:width_col, spill_h_hi_abs:height_col), - ) - - # We begin by copying the central region of the image which requires no padding at all. - # Eliminating the branches of the fully generalized version below gives us a nice - # speedup on the majority of the data. - for c in 1:channels - for kh in 1:kernel_h - for kw in 1:kernel_w - for h in (spill_h_lo+1):(height_col - spill_h_hi) - input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h - - @inbounds for w in (spill_w_lo+1):(width_col - spill_w_hi) - input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w - col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c] - end - end - end - end - end - - # For each "padded region", we run the fully general version - for (w_region, h_region) in padded_regions - for c in 1:channels - for kh in 1:kernel_h - for kw in 1:kernel_w - @inbounds for h in h_region - input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h - - # If this column is off the edge, then deal with the entire thing - # in one fell swoop, like a ravenous flock of crows. CAW CAW. - if input_kh <= 0 || input_kh > height - for w in w_region - col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped)) - end - continue - end - - @inbounds for w in w_region - input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w - - # If this pixel is off the edge of the map, clear it out. 
- if input_kw <= 0 || input_kw > width - col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped)) - continue - end - - # Copy the data over - col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c] - end - end - end - end - end - end -end - -function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int, height::Int, - channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, - stride_h::Int, dil_w::Int, dil_h::Int, mode::Int) where T - - height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1 - width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1 - channels_col = channels * kernel_h * kernel_w - - fill!(img, 0) - #pragma omp parallel for - for c = 1:channels_col - w_offset = (c - 1) % kernel_w - h_offset = div(c - 1, kernel_w) % kernel_h - c_im = div(c - 1, kernel_h * kernel_w) - if mode == 0 - w_offset = kernel_w - 1 - w_offset - h_offset = kernel_h - 1 - h_offset - end - for h = 1:height_col, w = 1:width_col - h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h - w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w - if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width - cval::T = col[((c - 1) * height_col + h - 1) * width_col + w] - img[(c_im * height + h_pad) * width + w_pad + 1] += cval - end - end - end -end - -function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int, height::Int, depth::Int, - channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Int, pad_h::Int, pad_d::Int, - stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T - - height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1 - width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1 - depth_col = div(depth + 2pad_d - (kernel_d - 1) * dil_d - 1, stride_d) + 1 - channels_col = channels * kernel_h * kernel_w * kernel_d - - - #pragma omp parallel for - for c = 1:channels_col - w_offset = (c - 1) % kernel_w - h_offset = div(c - 1, kernel_w) % kernel_h - d_offset = div(c - 1, kernel_w * kernel_h) % kernel_d - c_im = div(c - 1, kernel_w * kernel_h * kernel_d) - if mode == 0 - w_offset = kernel_w - 1 - w_offset - h_offset = kernel_h - 1 - h_offset - d_offset = kernel_d - 1 - d_offset - end - for d = 1:depth_col, h = 1:height_col, w = 1:width_col - d_pad = (d - 1) * stride_d - pad_d + d_offset * dil_d - h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h - w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w - if d_pad >= 0 && d_pad < depth && h_pad >= 0 && h_pad < height && - w_pad >= 0 && w_pad < width - col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w] = - img[((c_im * depth + d_pad) * height + h_pad) * width + w_pad + 1] - else - col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w] = 0 - end - end - end -end - -function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int, height::Int, - depth::Int, channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, - pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int, - dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T - - height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1 - width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1 - depth_col = div(depth + 2pad_d - (kernel_d - 1) * dil_d - 1, stride_d) + 1 - channels_col = channels * kernel_h * kernel_w * kernel_d - - fill!(img, 0) - #pragma omp 
parallel for - for c = 1:channels_col - w_offset = (c - 1) % kernel_w; - h_offset = div(c - 1, kernel_w) % kernel_h - d_offset = div(c - 1, kernel_w * kernel_h) % kernel_d - c_im = div(c - 1, kernel_h * kernel_w * kernel_d) - - if mode == 0 - w_offset = kernel_w - 1 - w_offset - h_offset = kernel_h - 1 - h_offset - d_offset = kernel_d - 1 - d_offset - end - - for d = 1:depth_col, h = 1:height_col, w = 1:width_col - d_pad = (d - 1) * stride_d - pad_d + d_offset * dil_d - h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h - w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w - if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width && - d_pad >= 0 && d_pad < depth - cval::T = col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w] - iidx = ((c_im * depth + d_pad) * height + h_pad) * width + w_pad + 1 - #pragma omp atomic - img[iidx] += cval - end - end - end -end - -function dilation_dims(w, dilation = 1) - N = ndims(w) - dims_w = size(w) - dil = psize(dilation, w) - ntuple(N) do i - if i < N - 1 - (dims_w[i] - 1) * dil[i] + 1 - else - dims_w[i] - end - end -end - -function im2col_dims(w,y) - N = ndims(y) - r,c = 1,1 - for i=1:N-2 - r *= size(y,i) - c *= size(w,i) - end - c *= size(w,N-1) - return (r, c) -end - -function im2col_dims(w::NTuple{4, Int}, y) - N = ndims(y) - r,c = 1,1 - for i=1:N-2 - r *= size(y,i) - c *= w[i] - end - c *= w[N-1] - return (r, c) -end - -function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - padding = 0, stride = 1, mode = 0, alpha = T(1)) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier - @assert Cx == Cw DimensionMismatch() - Wy,Hy,Cy,Ny = size(y) # Cy = Cw * Cm - dims_w = (Ww,Hw,Cw,Cm*Cw) - x2dims = im2col_dims(dims_w,y) - x2 = similar(x, x2dims) - (p1,p2) = psize(padding,x) - (s1,s2) = psize(stride,x) - M,N,K,Y = Wy*Hy,Cm,Ww*Hw,Wy*Hy*Cm - yidx = 1 - @inbounds for i in 1:Nx - im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode) - @inbounds for j in 1:Cx - gemm!('N','N',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(w,(j-1)*K*N+1),T(0),pointer(y,yidx)) - yidx += Y - end - end - return y -end - -function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4}; - padding=0, stride=1, mode=0, alpha=1) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier - @assert Cx == Cw DimensionMismatch() - Wy,Hy,Cy,Ny = size(dy) # Cy = Cw * Cm - @assert Cy == Cw * Cm DimensionMismatch() - dims_w = (Ww,Hw,Cw,Cm*Cw) - x2dims = im2col_dims(dims_w,dy) - x2 = similar(x, x2dims) - (p1,p2) = psize(padding,x) - (s1,s2) = psize(stride,x) - M,N,K,Y,W = Ww*Hw,Cm,Wy*Hy,Wy*Hy*Cm*Cx,Ww*Hw*Cm - alpha,beta = T(alpha),T(1) - dyidx = 1 - @inbounds for i in 1:Nx - im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode) - dwidx = 1 - @inbounds for j in 1:Cx - gemm!('T','N',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(dy,dyidx+(j-1)*K*N),beta,pointer(dw,dwidx)) - dwidx += W - end - dyidx += Y - end - return dw -end - -function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4}; - padding=0, stride=1, mode=0, alpha=1) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier - @assert Cx == Cw DimensionMismatch() - Wy,Hy,Cy,Ny = size(dy) # Cy = Cw * Cm - @assert Cy == Cw * Cm DimensionMismatch() - dims_w = (Ww,Hw,Cw,Cm*Cw) - x2dims = im2col_dims(dims_w,dy) - x2 = similar(x, x2dims) - M,N,K,Y,W = Wy*Hy,Ww*Hw,Cm,Wy*Hy*Cm*Cx,Ww*Hw*Cm 
- alpha,beta = T(alpha),T(0) - (p1,p2) = psize(padding,x) - (s1,s2) = psize(stride,x) - dyidx = 1 - @inbounds for i in 1:Nx - @inbounds for j in 1:Cx - gemm!('N','T',M,N,K,alpha,pointer(dy,dyidx+(j-1)*K*M),pointer(w,(j-1)*K*N+1),beta,pointer(x2,(j-1)*M*N+1)) - end - col2im2d!(dims_w,dx,x2,i,p1,p2,s1,s2,mode) - dyidx += Y - end - return dx -end - -function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, - cdims::ConvDims; alpha=T(1)) where T - Wx, Hx = img_size(cdims) - Ww, Hw = kernel_size(cdims) - Wy, Hy = output_size(cdims) - Cx = img_channels(cdims) - M, N, K, Y = Wy*Hy, size(y,3), prod(size(w)[1:3]), prod(size(y)[1:3]) - - x2 = similar(x, im2col_dims(w, y)) - @inbounds for n in 1:size(x,4) - im2col_2d!(view(x, :, :, :, n), x2, cdims) - gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(w),T(0),pointer(y,(n - 1)*Y + 1)) - end - return y -end - -function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T - if mode != 0 && mode != 1 - throw(ArgumentError("conv2d only supports mode=0 or 1.")) - end - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = size(w) - - # Check that the number of channels in `x` matches the number of channels in each - # kernel of `w`. IF it doesn't, throw a DimensionMismatch() - if Cx != C1 - throw(DimensionMismatch()) - end - (p1,p2) = psize(padding,x) - (s1,s2) = psize(stride,x) - (d1,d2) = psize(dilation, x) - - cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(d1,d2), mode == 0}() - return conv2d!(y, x, w, cdims; alpha=alpha) -end - -function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, dy::AbstractArray{T,4}; - padding=0, stride=1, dilation=1, mode=0, alpha=1) where T - # dw = x'*dy - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = size(dw) - Wy,Hy,Cy,Ny = size(dy) - # if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end - @assert Cx==C1 && Cy==C2 && Ny==Nx - x2dims = im2col_dims(dw,dy) - x2 = similar(x, x2dims) - # op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix. - Y,M,N,K = Wy*Hy*Cy,Ww*Hw*Cx,Cy,Wy*Hy - alpha,beta = T(alpha),T(1) - (p1,p2) = psize(padding,x) - (s1,s2) = psize(stride,x) - (d1,d2) = psize(dilation,x) - dyi = 1 - @inbounds for n in 1:Nx - im2col2d!(dw, x, x2, n, p1, p2, s1, s2, d1, d2, mode) - gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw)) - dyi += Y - end - return dw -end - -function conv2d_grad_x!(dx::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4}; - padding=0, stride=1, dilation=1, mode=0, alpha=1) where T - # dx = dy*w' - Wx,Hx,Cx,Nx = size(dx) - Ww,Hw,C1,C2 = size(w) - Wy,Hy,Cy,Ny = size(dy) - # if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end - @assert Cx==C1 && Cy==C2 && Ny==Nx - x2dims = im2col_dims(w,dy) - x2 = similar(dx, x2dims) - # op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix. 
- Y,M,N,K = Wy*Hy*Cy,Wy*Hy,Ww*Hw*Cx,Cy - alpha,beta = T(alpha),T(0) - (p1,p2) = psize(padding,dx) - (s1,s2) = psize(stride,dx) - (d1,d2) = psize(dilation,dx) - dyi = 1 - @inbounds for n in 1:Nx - gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2)) - col2im2d!(w,dx,x2,n,p1,p2,s1,s2,d1,d2,mode) - dyi += Y - end - return dx -end - -function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = w - xn = view(x, :, :, :, n) - cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(1,1), mode == 0}() - im2col_2d!(xn,x2,cdims) - return x2 -end - -function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = size(w) - xn = view(x, :, :, :, n) - cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(d1,d2), mode == 0}() - im2col_2d!(xn,x2,cdims) - return x2 -end - -function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = w - xn = view(x, :, :, :, n) - col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode) - return x -end - -function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T - Wx,Hx,Cx,Nx = size(x) - Ww,Hw,C1,C2 = size(w) - xn = view(x, :, :, :, n) - col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode) - return x -end - -function conv3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; - padding=0, stride=1, dilation = 1, mode=0, alpha=T(1)) where T - if mode != 0 && mode != 1; throw(ArgumentError("conv3d only supports mode=0 or 1.")); end - Wx,Hx,Dx,Cx,Nx = size(x) - Ww,Hw,Dw,C1,C2 = size(w) - if Cx!=C1; throw(DimensionMismatch()); end - Wy,Hy,Dy,Cy,Ny = size(y) - # @assert Cy==C2 && Ny==Nx - x2dims = im2col_dims(w,y) - x2 = similar(x, x2dims) - (p1,p2,p3) = psize(padding,x) - (s1,s2,s3) = psize(stride,x) - (d1,d2,d3) = psize(dilation,x) - M,N,K,Y = Wy*Hy*Dy,Cy,Ww*Hw*Dw*Cx,Wy*Hy*Dy*Cy - yidx = 1 - W = reshape(w, (size(w, 1),:,C1,C2)) - @inbounds for n in 1:Nx - im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode) - gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(W),T(0),pointer(y,yidx)) - yidx += Y - end - return y -end - -function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}; - padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T - # dw = x'*dy - Wx,Hx,Dx,Cx,Nx = size(x) - Ww,Hw,Dw,C1,C2 = size(dw) - Wy,Hy,Dy,Cy,Ny = size(dy) - # if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end - @assert Cx==C1 && Cy==C2 && Ny==Nx - x2dims = im2col_dims(dw,dy) - x2 = similar(x, x2dims) - # op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix. 
- Y,M,N,K = Wy*Hy*Dy*Cy,Ww*Hw*Dw*Cx,Cy,Wy*Hy*Dy - alpha,beta = T(alpha),T(1) - (p1,p2,p3) = psize(padding,x) - (s1,s2,s3) = psize(stride,x) - (d1,d2,d3) = psize(dilation,x) - dyi = 1 - @inbounds for n in 1:Nx - im2col3d!(dw, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode) - gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw)) - dyi += Y - end - return dw -end - -function conv3d_grad_x!(dx::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5}; - padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T - # dx = dy*w' - Wx,Hx,Dx,Cx,Nx = size(dx) - Ww,Hw,Dw,C1,C2 = size(w) - Wy,Hy,Dy,Cy,Ny = size(dy) - # if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end - @assert Cx==C1 && Cy==C2 && Ny==Nx - x2dims = im2col_dims(w,dy) - x2 = similar(dx, x2dims) - # op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix. - Y,M,N,K = Wy*Hy*Dy*Cy,Wy*Hy*Dy,Ww*Hw*Dw*Cx,Cy - alpha,beta = T(alpha),T(0) - (p1,p2,p3) = psize(padding,dx) - (s1,s2,s3) = psize(stride,dx) - (d1,d2,d3) = psize(dilation,dx) - dyi = 1 - @inbounds for n in 1:Nx - gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2)) - col2im3d!(w,dx,x2,n,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode) - dyi += Y - end - return dx -end - -function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int, - s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Ww,Hw,Dw,C1,C2 = size(w) - xn = view(x, :, :, :, :, n) - im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode) - return x2 -end - -function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2}, - n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int, - s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Ww,Hw,Dw,C1,C2 = size(w) - xn = view(x, :, :, :, :, n) - col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode) - return x -end diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl new file mode 100644 index 000000000..b556a9d14 --- /dev/null +++ b/src/impl/conv_direct.jl @@ -0,0 +1,148 @@ +## This file contains direct Julia implementations of 2d and 3d convolutions + +# Helper functions for restricting x/w overreach +function clamp_lo(x, w) + idx = 1 + while idx <= length(x) && x[idx] <= 0 + idx += 1 + end + return (x[idx:end], w[idx:end]) +end +function clamp_hi(x, w, L) + idx = length(x) + while idx >= 1 && x[idx] > L + idx -= 1 + end + return (x[1:idx], w[1:idx]) +end + +""" + conv_direct!(y, x, w, cdims; alpha=1, beta=0) + +Direct convolution implementation; used for debugging, tests, and mixing/matching of +strange datatypes within a single convolution. Uses naive nested for loop implementation +and does not attempt to optimize performance. Rather, this implementation is intended to +be maximally understandable and debuggable, to aid in testing other, more performant +implementations. We also explicitly support mixing and matching of strange datatypes, +so that if the user really wants to convolve an image of `UInt8`'s with a `Float16` +kernel, storing the result in a `Float32` output, there is at least a function call +for that madness. 
+ +The keyword arguments `alpha` and `beta` control accumulation behavior; this function +calculates `y = alpha * x * w + beta * y`, therefore by setting `beta` to a nonzero +value, the user is able to accumulate values into a preallocated `y` buffer, or by +setting `alpha` to a nonunitary value, an arbitrary gain factor can be applied. + +By defaulting `beta` to `false`, we make use of the Bradbury promotion trick to override +`NaN`'s that may pre-exist within our output buffer, as `false*NaN == 0.0`, whereas +`0.0*NaN == NaN`. Only set `beta` if you are certain that none of the elements within +`y` are `NaN`. + +The basic implementation performs 3-dimensional convolution; 1-dimensional and 2- +dimensional casesa are supported by simply reshaping `y`, `x` and `w`, for which +wrapper methods are available. +""" +conv_direct! + +@timeit_debug to function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, + w::AbstractArray{wT,5}, cdims::DenseConvDims; + alpha::yT = yT(1), beta = false) where {yT, xT, wT} + check_dims(size(x), size(w), size(y), cdims) + + width, height, depth = input_size(cdims) + kernel_w, kernel_h, kernel_d = kernel_size(cdims) + out_c = channels_out(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + dil_w, dil_h, dil_d = dilation(cdims) + stride_w, stride_h, stride_d = stride(cdims) + out_width, out_height, out_depth = output_size(cdims) + + # If we're doing crosscorr instead of conv, then don't bother to flip `w` + if !flipkernel(cdims) + w = w[end:-1:1, end:-1:1, end:-1:1, :, :] + end + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + # explicit formulation of convolution. Oh hoisting gods, hear my plea. + @inbounds for batch in 1:size(x)[end], + c_out in 1:out_c, + h_idx in 1:out_height, + w_idx in 1:out_width, + d_idx in 1:out_depth + + # Starting points of the window of x we're going to grab + x_w = project(w_idx, stride_w, pad_w_lo) + x_h = project(h_idx, stride_h, pad_h_lo) + x_d = project(d_idx, stride_d, pad_d_lo) + + # Grow that starting point into ranges + x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1)) + x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1)) + x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1)) + w_widxs = 1:kernel_w + w_hidxs = 1:kernel_h + w_didxs = 1:kernel_d + + # Clamp the ranges to simulate padding + x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs) + x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width) + x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs) + x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height) + x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs) + x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth) + + # Grab our slices + x_slice = view(x, x_widxs, x_hidxs, x_didxs, :, batch) + w_slice = view(w, w_widxs, w_hidxs, w_didxs, :, c_out) + + # Do the dotproduct dance, then weight by alpha/beta and git 'er done + dotprod = sum(x_slice .* w_slice) + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + + beta*y[w_idx, h_idx, d_idx, c_out, batch] + end + + return y +end + +## Gradient definitions +""" + ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) + +Calculate the gradient imposed upon `x` in the convolution `y = x * w`. +""" +∇conv_data_direct! 
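+# A small 1-d worked example (illustrative only; stride 1, no padding,
+# cross-correlation mode): for x = [x1, x2, x3] and w = [w1, w2],
+#   y  = [x1*w1 + x2*w2,  x2*w1 + x3*w2]
+# and therefore
+#   dx = [dy1*w1,  dy1*w2 + dy2*w1,  dy2*w2],
+# i.e. `dy` zero-padded by one on each side and correlated with the flipped kernel
+# [w2, w1]. This is the "transposed convolution" that the method below computes by
+# flipping `w`, pre-dilating `dy` by the stride, and transposing the padding.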
+ +@timeit_debug to function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, + w::AbstractArray{wT,5}, cdims::DenseConvDims; + alpha::xT=xT(1), beta=false) where {xT, yT, wT} + w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]) + dy = predilate(dy, stride(cdims)) + ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims), + dilation=dilation(cdims), + flipkernel=flipkernel(cdims)) + dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta) + return transpose_swapbatch(dx) +end + +""" + ∇conv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) + +Calculate the gradient imposed upon `w` in the convolution `y = x * w`. +""" +∇conv_filter_direct! + +@timeit_debug to function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, + dy::AbstractArray{yT,5}, cdims::DenseConvDims; + alpha::wT=wT(1), beta=false) where {xT, yT, wT} + x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]) + dy = transpose_swapbatch(predilate(dy, stride(cdims))) + ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), + stride=dilation(cdims)) + conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) + if flipkernel(cdims) + dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] + end + return dw +end \ No newline at end of file diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl new file mode 100644 index 000000000..12b808a4d --- /dev/null +++ b/src/impl/conv_im2col.jl @@ -0,0 +1,366 @@ +## This file contains im2col-backed implementations of convolution for 2d and 3d +## convolutions. Expect to see a lot of indexing. + +# Helper functions for flipkernel-induced dyslexia +@inline function kernel_index(w, h, d, cdims::ConvDims{N, S, P, D, false}) where {N, S, P, D} + kernel_w, kernel_h, kernel_d = kernel_size(cdims) + return (kernel_w - w + 1, kernel_h - h + 1, kernel_d - d + 1) +end +@inline function kernel_index(w, h, d, cdim::ConvDims{N, S, P, D, true}) where {N, S, P, D} + return (w, h, d) +end + +""" + conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) + +Perform a convolution using im2col and GEMM, store the result in `y`. The kwargs +`alpha` and `beta` control accumulation behavior; internally this operation is +implemented as a matrix multiply that boils down to `y = alpha * x * w + beta * y`, thus +by setting `beta` to a nonzero value, multiple results can be accumulated into `y`, or +by setting `alpha` to a nonunitary value, various gain factors can be applied. + +Note for the particularly performance-minded, you can provide a pre-allocated `col`, +which should eliminate any need for large allocations within this method. +""" +@timeit_debug to function conv_im2col!( + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + w::AbstractArray{T,5}, cdims::DenseConvDims; + col::AbstractArray{T,2}=similar(x, im2col_dims(cdims)), + alpha::T=T(1), beta::T=T(0)) where {T} + check_dims(size(x), size(w), size(y), cdims) + + # COL * W -> Y + # [M x K] * [K x N] -> [M x N] + # + # M: output spatial resolution + # N: output channels + # K: size of input "patch" (kernel size and input channels combined) + # + # In english, we're grabbing each input patch and laying them out along + # the M dimension in `col`, so that the GEMM call below multiplies each + # kernel (which is kernel_h * kernel_w * channels_in elments long) is + # dotproducted with that input patch, effectively computing a convolution + # in a somewhat memory-wasteful but easily-computed way (since we already + # have an extremely highly-optimized GEMM call available in BLAS). 
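+    # A concrete instance (illustrative numbers only): a 5x5x1 input with C_in = 2,
+    # a 3x3x1 kernel with C_out = 4, stride 1 and no padding gives a 3x3x1 output, so
+    #   M = 3*3*1 = 9,  K = 3*3*1*2 = 18,  N = 4,
+    # i.e. `col` is a 9x18 matrix and `w` acts as an 18x4 matrix in the GEMM below,
+    # producing a 9x4 block of `y` per batch element.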
+ M = prod(output_size(cdims)) + N = channels_out(cdims) + K = prod(kernel_size(cdims))*channels_in(cdims) + + @inbounds for batch_idx in 1:size(x,5) + # We invoke `@timeit_debug` on the outside of `im2col!()` because inference + # doesn't like us putting it on the inside. + @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims) + col_ptr = pointer(col) + w_ptr = pointer(w) + y_ptr = pointer(y, (batch_idx - 1)*M*N + 1) + @timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) + end + return y +end + +""" + ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0) + +Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`. +See the documentation for `conv_im2col!()` for explanation of optional parameters. +""" +@timeit_debug to function ∇conv_filter_im2col!( + dw::AbstractArray{T,5}, x::AbstractArray{T,5}, + dy::AbstractArray{T,5}, cdims::DenseConvDims; + col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)), + alpha::T=T(1), beta::T=T(0)) where {T} + check_dims(size(x), size(dw), size(dy), cdims) + + # COL' * dY -> dW + # [M x K] * [K x N] -> [M x N] + # + # M: size of input "patch" (kernel size and input channels combined) + # N: output channels + # K: output spatial resolution + # + # In english, we're grabbing each input patch and laying them out along + # the K dimension in `col`, then multiplying in `dY` to compute a dot + # product between all pixels in the input that were multiplied by a single + # position in the W kernel, and all output pixels of the same location, + # across output channels. This slice of `col` therefore constitutes every + # input pixel that touched a particular element of the kernel. + # + # This is identical to a convolution between x and a dimension-permuted dY, + # where we + + M = prod(kernel_size(cdims))*channels_in(cdims) + N = channels_out(cdims) + K = prod(output_size(cdims)) + + @inbounds for batch_idx in 1:size(x,5) + # We invoke `@timeit_debug` on the outside of `im2col!()` because inference + # doesn't like us putting it on the inside. + @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims) + col_ptr = pointer(col) + dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1) + dw_ptr = pointer(dw) + @timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + + # Because we accumulate over batches in this loop, we must set `beta` equal + # to `1.0` from this point on. + beta = T(1) + end + return dw +end + +""" + ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + +Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. +See the documentation for `conv_im2col!()` for explanation of other parameters. +""" +@timeit_debug to function ∇conv_data_im2col!( + dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, + w::AbstractArray{T,5}, cdims::DenseConvDims; + col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)), + alpha::T=T(1), beta::T=T(0)) where {T} + check_dims(size(dx), size(w), size(dy), cdims) + + # dY W' -> dX + # [M x K] * [K x N] -> [M x N] + # + # M: output spatial resolution + # N: size of input "patch" (kernel size and input channels combined) + # K: output channels + # + # In english, we're taking the output image and laying it out by pixel, + # with channels lying along the `K` dimension in `col`. We then multiply + # in `W'` to compute a dot product between each pixel location and the + # entire kernel. 
This dot product therefore constitutes every output pixel + # that was a function of a particular input pixel. + # + # This is identical to a transposed convolution between dY and W + + M = prod(output_size(cdims)) + N = prod(kernel_size(cdims))*channels_in(cdims) + K = channels_out(cdims) + + @inbounds for batch_idx in 1:size(dx, 5) + dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) + w_ptr = pointer(w) + col_ptr = pointer(col) + @timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) + @timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims) + end + return dx +end + + + + + +""" + im2col!(col, x, cdims) + +Converts a 3d image `x` into a matrix `col` for usage with GEMM-calculated convolution. +Patches of `x` of size (kernel_w, kernel_h, kernel_d, C_in) will be extracted and laid +out along the rows of `col`, one for each output pixel. This routine is used by all +im2col-based convolutions, just with extra singleton dimensions added in the case of `2d` +or `1d` images. +""" +function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, + cdims::ConvDims) where {T} + if spatial_dims(cdims) != 3 + throw(DimensionMismatch("im2col!() only accepts 3d convoluitional inputs")) + end + + # Extract those nice, compile-time constant type parameters from `cdims`. + width, height, depth = input_size(cdims) + kernel_w, kernel_h, kernel_d = kernel_size(cdims) + C_in = channels_in(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + dil_w, dil_h, dil_d = dilation(cdims) + stride_w, stride_h, stride_d = stride(cdims) + out_width, out_height, out_depth = output_size(cdims) + + # Reshape col for easy access. + col_reshaped = reshape(col, ( + # Output resolution + out_width, + out_height, + out_depth, + + # By input patch size + kernel_w, + kernel_h, + kernel_d, + C_in, + )) + + padded_regions, central_region = calc_padding_regions(cdims) + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + + # We begin by copying the central region of the image which requires no padding at all. + # Eliminating the branches of the fully generalized version below gives us a nice + # speedup on the majority of the data. 
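+    # For example (illustrative numbers only): with stride_w = 2 and pad_w_lo = 1,
+    # project(1, 2, 1) == 0, so output column 1 starts inside the padding, while
+    # project(2, 2, 1) == 2 starts at the second input column. Inside the central
+    # region every project(...) + (k - 1)*dil index is guaranteed to be in-bounds,
+    # which is what allows the loop below to omit the bounds checks.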
+ @timeit_debug to "im2col!() - central region" begin + @inbounds for c in 1:C_in + # Unpack "central region" + w_region, h_region, d_region = central_region + + for kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w, + d in d_region, + h in h_region, + w in w_region + + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + kidxs = kernel_index(kw, kh, kd, cdims) + + xval::T = x[input_kw, input_kh, input_kd, c] + col_reshaped[w, h, d, kidxs..., c] = xval + end + end + end + + # For each "padded region", we run the fully general version + @timeit_debug to "im2col!() - padded region" begin + @inbounds for (w_region, h_region, d_region) in padded_regions + for c in 1:C_in, + d in d_region, + h in h_region, + w in w_region, + kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w + + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + + kidxs = kernel_index(kw, kh, kd, cdims) + + # If this d is off the edge, then deal with the entire plane + # in one fell swoop, like a ravenous flock of crows. CAW CAW. + if input_kd <= 0 || input_kd > depth + for kh in 1:kernel_h, + kw in 1:kernel_w + col_reshaped[w, h, d, kidxs..., c] = T(0) + end + continue + end + + # Same for `h`, but in this case it's only a line, not a plane. + # This results in slightly less caw'ing. + if input_kh <= 0 || input_kh > height + for kw in 1:kernel_w + col_reshaped[w, h, d, kidxs..., c] = T(0) + end + continue + end + + # If this `w` is off the edge it and only it gets cleared out + if input_kw <= 0 || input_kw > width + col_reshaped[w, h, d, kidxs..., c] = T(0) + continue + end + + # Copy the data over + xval::T = x[input_kw, input_kh, input_kd, c] + col_reshaped[w, h, d, kidxs..., c] = xval + end + end + end +end + + +""" + col2im!(x, col, cdims) + +Does the inverse of `im2col!()`, converting `col` back into a 3d image, used for backward +passes, transposed convolutions, etc... + +Note that this method has not been optimized in the same way as `im2col()` has, because +it is slightly more complicated due to the more chaotic data access patterns, and I'm not +desperate enough yet. +""" +col2im! + +function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, + cdims::ConvDims) where T + if spatial_dims(cdims) != 3 + throw(DimensionMismatch("col2im!() only accepts 3d convoluitional inputs")) + end + + # Extract those nice, compile-time constant type parameters from `cdims`. + width, height, depth = input_size(cdims) + kernel_w, kernel_h, kernel_d = kernel_size(cdims) + C_in = channels_in(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + dil_w, dil_h, dil_d = dilation(cdims) + stride_w, stride_h, stride_d = stride(cdims) + out_width, out_height, out_depth = output_size(cdims) + + # TODO: Rewrite this method so we don't have this fill!() at the beginning! + # Calculate each output pixel once rather than accumulating into it? + fill!(x, T(0)) + + # Reshape col for easy access. 
+ col_reshaped = reshape(col, ( + # Output resolution + out_width, + out_height, + out_depth, + + # By input patch size + kernel_w, + kernel_h, + kernel_d, + C_in, + )) + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + @inbounds for c in 1:C_in + for kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w + + for d in 1:out_depth + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + + # If this d is off the edge, then deal with the entire plane + # in one fell swoop, like a ravenous flock of crows. CAW CAW. + if input_kd <= 0 || input_kd > depth + continue + end + + for h in 1:out_height + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + + # Same for `h`, but in this case it's only a line, not a plane. + # This results in slightly less caw'ing. + if input_kh <= 0 || input_kh > height + continue + end + + for w in 1:out_width + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + + # If this `w` is off the edge, only it gets cleared out. + if input_kw <= 0 || input_kw > width + continue + end + + # Copy the data over + kidxs = kernel_index(kw, kh, kd, cdims) + cval::T = col_reshaped[w, h, d, kidxs..., c] + x[input_kw, input_kh, input_kd, c] += cval + end + end + end + end + end +end diff --git a/src/impl/depthwiseconv_direct.jl b/src/impl/depthwiseconv_direct.jl new file mode 100644 index 000000000..5dc001625 --- /dev/null +++ b/src/impl/depthwiseconv_direct.jl @@ -0,0 +1,158 @@ +## This file contains direct Julia implementations of depwthwise convolutions + +""" + depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0) + +Direct depthwise convolution implementation; used for debugging, tests, and mixing/ +matching of strange datatypes within a single convolution. Uses naive nested for loop +implementation and does not attempt to optimize performance. Rather, this implementation +is intended to be maximally understandable and debuggable, to aid in testing other, more +performant implementations. We also explicitly support mixing and matching of strange +datatypes, so that if the user really wants to convolve an image of `UInt8`'s with a +`Float16` kernel, storing the result in a `Float32` output, there is at least a function +call for that madness. + +One subtlety about depthwise convolutions; the shape of a depthwise convolutional kernel +is `(spatial_dims..., C_mult, C_in)`, so the axis that must match with the number of +channels in `x` is the last, not the second-to-last, as in a normal dense convolution. + +See the docstring for `conv_direct!()` for more on the optional parameters. 
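+
+For example, a 3x3 depthwise kernel with channel multiplier `C_mult = 2` applied to an
+input with `C_in = 8` channels has size `(3, 3, 2, 8)` and yields `C_in * C_mult = 16`
+output channels, with output channels `2*c_in - 1` and `2*c_in` computed from input
+channel `c_in`.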
+""" +@timeit_debug to function depthwiseconv_direct!( + y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, + w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; + alpha::yT = yT(1), beta::yT = yT(0)) where {yT, xT, wT} + check_dims(size(x), size(w), size(y), cdims) + + width, height, depth = input_size(cdims) + kernel_w, kernel_h, kernel_d = kernel_size(cdims) + out_c = channels_out(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + dil_w, dil_h, dil_d = dilation(cdims) + stride_w, stride_h, stride_d = stride(cdims) + out_width, out_height, out_depth = output_size(cdims) + + # If we're doing crosscorr instead of conv, then don't bother to flip `w` + if !flipkernel(cdims) + w = w[end:-1:1, end:-1:1, end:-1:1, :, :] + end + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + # explicit formulation of convolution. Oh hoisting gods, hear my plea. + @inbounds for batch in 1:size(x)[end], + c_mult in 1:channel_multiplier(cdims), + c_in in 1:channels_in(cdims), + h_idx in 1:out_height, + w_idx in 1:out_width, + d_idx in 1:out_depth + + # Starting points of the window of x we're going to grab + x_w = project(w_idx, stride_w, pad_w_lo) + x_h = project(h_idx, stride_h, pad_h_lo) + x_d = project(d_idx, stride_d, pad_d_lo) + + # Grow that starting point into ranges + x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1)) + x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1)) + x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1)) + w_widxs = 1:kernel_w + w_hidxs = 1:kernel_h + w_didxs = 1:kernel_d + + # Clamp the ranges to simulate padding + x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs) + x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width) + x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs) + x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height) + x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs) + x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth) + + # Grab our slices (for a single channel pairing, as this is depthwise) + c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult + x_slice = view(x, x_widxs, x_hidxs, x_didxs, c_in, batch) + w_slice = view(w, w_widxs, w_hidxs, w_didxs, c_mult, c_in) + + # Do the dotproduct dance, then weight by alpha/beta and git 'er done + dotprod = sum(x_slice .* w_slice) + prev_yval::yT = beta*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + prev_yval + end + + return y +end + +""" + ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) + +Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`. +We make use of the fact that a depthwise convolution is equivalent to `C_in` separate +normal convolutions between that channel of `x` and the `C_mult` different kernels that +get applied to it. The output of such a convolution is the gradient imposed upon that +particular channel of `x`, and so we simply walk through `x`, calculating the gradient +for each batch and channel independently. +""" +∇depthwiseconv_data_direct! 
+ +@timeit_debug to function ∇depthwiseconv_data_direct!( + dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, + w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; + alpha::xT=xT(1), beta::xT=xT(0)) where {xT, yT, wT} + # We do a separate convolution for each channel in x + @inbounds for cidx in 1:channels_in(cdims) + # For this batch and in-channel, we have a normal transposed convolution + # between this slice of `x` and the corresponding slices of `w` and `dy`: + dx_slice = view(dx, :, :, :, cidx:cidx, :) + C_mult = channel_multiplier(cdims) + dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) + w_slice = permutedims(view(w, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) + + # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out + # channels appropriately for this one convolution. + cdims_slice = DenseConvDims(cdims; + C_in=1, + C_out=channel_multiplier(cdims), + ) + + ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice; + alpha=alpha, beta=beta) + end + return dx +end + +""" + ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) + +Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`. +""" +∇depthwiseconv_filter_direct! + +@timeit_debug to function ∇depthwiseconv_filter_direct!( + dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, + dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims; + alpha::wT=wT(1),beta::wT=wT(0)) where {xT, yT, wT} + # We do a separate convolution for each channel in x + @inbounds for cidx in 1:channels_in(cdims) + # For this batch and in-channel, we have a normal transposed convolution + # between this slice of `x` and the corresponding slices of `w` and `dy`: + x_slice = view(x, :, :, :, cidx:cidx, :) + C_mult = channel_multiplier(cdims) + dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) + dw_slice = permutedims(view(dw, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) + + # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out + # channels appropriately for this one convolution. + cdims_slice = DenseConvDims(cdims; + C_in=1, + C_out=channel_multiplier(cdims), + ) + + ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice; + alpha=alpha, beta=beta) + dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4)) + end + return dw +end + + diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl new file mode 100644 index 000000000..d87e07aac --- /dev/null +++ b/src/impl/depthwiseconv_im2col.jl @@ -0,0 +1,119 @@ +## This file contains adapter code for doing depthwise convolutions with im2col. + + +""" + depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) + +Perform a depthwise convolution using im2col and GEMM, store the result in `y`. + +See `conv_im2col!()` for an explanation of optional parameters. +""" +depthwiseconv_im2col! + +@timeit_debug to function depthwiseconv_im2col!( + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + w::AbstractArray{T,5}, cdims::DepthwiseConvDims; + col::AbstractArray{T,2} = similar(x, im2col_dims(cdims)), + alpha=T(1), beta=T(0)) where T + check_dims(size(x), size(w), size(y), cdims) + + # This functions exactly the same as conv_im2col!(), except that we shard the + # incoming data into slices of single channels. This means that we need to walk + # each pointer forward individually, as done below, taking a single input channel + # and combining it with each kernel individually, before walking forward and doing + # the next input channel. 
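+    # Illustrative shapes only: with a 3x3 kernel, an 8x8 output and C_mult = 2,
+    #   M = 64,  K = 9,  N = 2,
+    # so each input channel owns a 64x9 slab of `col`, is multiplied against its own
+    # 9x2 slice of `w`, and writes a 64x2 block of `y`; this is why the pointers below
+    # advance by M*K, K*N and M*N elements per input channel, respectively.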
+ M = prod(output_size(cdims)) + N = channel_multiplier(cdims) + K = prod(kernel_size(cdims)) + + dcdims = DenseConvDims(cdims) + @inbounds for batch_idx in 1:size(x)[end] + # We invoke `@timeit_debug` on the outside of `im2col!()` because inference + # doesn't like us putting it on the inside. + @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), dcdims) + + # We do a separate convolution for each channel in x, as we must + for c_in in 1:channels_in(cdims) + # Walk each pointer forward as we process each input channel + col_ptr = pointer(col, (c_in-1)*M*K+1) + w_ptr = pointer(w, (c_in-1)*K*N+1) + y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) + gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) + end + end + return y +end + +""" + ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) + +Depthwise conv2d backward pass onto the weights using im2col and GEMM. +See the documentation for `conv_im2col!()` for explanation of optional parameters. +""" +∇depthwiseconv_filter_im2col! + +@timeit_debug to function ∇depthwiseconv_filter_im2col!( + dw::AbstractArray{T,5}, x::AbstractArray{T,5}, + dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; + col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)), + alpha=T(1), beta=T(0)) where T + check_dims(size(x), size(dw), size(dy), cdims) + + M = prod(kernel_size(cdims)) + N = channel_multiplier(cdims) + K = prod(output_size(cdims)) + + @inbounds for batch_idx in 1:size(x)[end] + # We invoke `@timeit_debug` on the outside of `im2col!()` because inference + # doesn't like us putting it on the inside. + @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims) + + # We do a separate convolution for each channel in x, as we must + for c_in in 1:channels_in(cdims) + # Walk each pointer forward as we process each input channel + col_ptr = pointer(col, (c_in - 1)*M*K + 1) + dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) + dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) + + gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + end + + # Because we accumulate over batches in this loop, we must set `beta` equal + # to `1.0` from this point on. + beta = T(1) + end + return dw +end + +""" + depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + +Depwthwise conv2d backward pass onto the input using im2col and GEMM. +See the documentation for `conv_im2col!()` for explanation of optional parameters. +""" +∇depthwiseconv_data_im2col! + +@timeit_debug to function ∇depthwiseconv_data_im2col!( + dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, + w::AbstractArray{T,5}, cdims::DepthwiseConvDims; + col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)), + alpha=T(1), beta=T(0)) where T + check_dims(size(dx), size(w), size(dy), cdims) + + M = prod(output_size(cdims)) + N = prod(kernel_size(cdims)) + K = channel_multiplier(cdims) + + @inbounds for batch_idx in 1:size(dx)[end] + # We do a separate convolution for each channel in x, as we must + for cidx in 1:channels_in(cdims) + # Walk each pointer forward as we process each input channel + dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) + w_ptr = pointer(w, (cidx - 1)*K*N + 1) + col_ptr = pointer(col, (cidx - 1)*M*N + 1) + gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) + end + @timeit_debug to "col2im!" 
col2im!(view(dx, :, :, :, :, batch_idx), col, cdims) + end + return dx +end \ No newline at end of file diff --git a/src/impl/padding_edges.jl b/src/impl/padding_edges.jl new file mode 100644 index 000000000..1d436ea56 --- /dev/null +++ b/src/impl/padding_edges.jl @@ -0,0 +1,101 @@ +""" + calc_padding_regions(dims) + +Padding is a jerk. A HUGE jerk that tries to sneak a bunch of conditionals and edge +cases (quite literally) into our beautiful stencil operations such as convolution, +pooling, etc... The way we deal with this is to, first, deal with everything in 3d, +and then define a single padding region helper function that returns the seven regions +that all 3d operations must deal with, including the central "unpadded" region where we +can run at full bore, not paying any attention to padding. +""" +function calc_padding_regions(dims) + width, height, depth = input_size(dims) + kernel_w, kernel_h, kernel_d = kernel_size(dims) + C_in = channels_in(dims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(dims) + dil_w, dil_h, dil_d = dilation(dims) + stride_w, stride_h, stride_d = stride(dims) + out_width, out_height, out_depth = output_size(dims) + + # Let us first calculate the number of rows/cols within which we must zero out some + # portion of the image patches we're copying over. The "spillage" here is the number + # of indices along a particular dimension for which a kernel will have some portion + # of its input domain overlapping the padding. If padding is zero, these values are + # all trivially zero. The low spillage is trivially the low padding divided by the + # stride; literally the number of shifts that overlap some padding. The high + # spillage is slightly more complicated; we first figure out how many elements of + # high padding are wasted (e.g. through strides not fitting to the end perfectly) + # subtract that from the high padding, then do the same: + calc_lo_spill(O, S, P) = min(ceil(Int, P/S), O) + @inline function calc_hi_spill(O, S, Pl, Ph, K, D, I) + wasted_Ph = (I + Pl + Ph - (K - 1)*D - 1)%S + return min(ceil(Int, (Ph - wasted_Ph)/S), O) + end + + spill_w_lo = calc_lo_spill(out_width, stride_w, pad_w_lo) + spill_w_hi = calc_hi_spill(out_width, stride_w, pad_w_lo, pad_w_hi, kernel_w, dil_w, width) + spill_h_lo = calc_lo_spill(out_height, stride_h, pad_h_lo) + spill_h_hi = calc_hi_spill(out_height, stride_h, pad_h_lo, pad_h_hi, kernel_h, dil_h, height) + spill_d_lo = calc_lo_spill(out_depth, stride_d, pad_d_lo) + spill_d_hi = calc_hi_spill(out_depth, stride_d, pad_d_lo, pad_d_hi, kernel_d, dil_d, depth) + + spill_w_hi_abs = out_width - spill_w_hi + 1 + spill_h_hi_abs = out_height - spill_h_hi + 1 + spill_d_hi_abs = out_depth - spill_d_hi + 1 + + # These are the regions we're going to have to run with cognizance of padding. + # There are six of them; one for each face of the cube image. We explicitly + # design this so that we run over `width` most tightly, in the expectation that + # this will generate better code for when `h` and `d` are singleton dimensions. + # We visualize this as a cube, indexed by dimensions (w, h, d). + padded_regions = ( + # First region is the lower-d WH face: + ( + 1:out_width, + 1:out_height, + 1:spill_d_lo, + ), + + # The next largest chunk we choose will be the lower-h WD faces; we always + # want to maximize going across full `w`, as its contiguous in memory. 
+ ( + 1:out_width, + 1:spill_h_lo, + (spill_d_lo+1):(spill_d_hi_abs-1), + ), + # Then the upper-h WD face + ( + 1:out_width, + spill_h_hi_abs:out_height, + (spill_d_lo+1):(spill_d_hi_abs-1), + ), + + # Next, we fit the HD faces in, but without overlapping the `h` and `d` + # regions we've done before: + ( + 1:spill_w_lo, + (spill_h_lo+1):(spill_h_hi_abs-1), + (spill_d_lo+1):(spill_d_hi_abs-1), + ), + ( + spill_w_hi_abs:out_width, + (spill_h_lo+1):(spill_h_hi_abs-1), + (spill_d_lo+1):(spill_d_hi_abs-1) + ), + + # Last region is the higher-d WH face: + ( + 1:out_width, + 1:out_height, + spill_d_hi_abs:out_depth, + ), + ) + + # The central region that has no padding. + central_region = ( + (spill_w_lo+1):(spill_w_hi_abs - 1), + (spill_h_lo+1):(spill_h_hi_abs - 1), + (spill_d_lo+1):(spill_d_hi_abs - 1), + ) + return padded_regions, central_region +end \ No newline at end of file diff --git a/src/impl/pool.jl b/src/impl/pool.jl deleted file mode 100644 index b9ffa8deb..000000000 --- a/src/impl/pool.jl +++ /dev/null @@ -1,294 +0,0 @@ -function max_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4}, - width::Int, height::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, - stride_w::Int, stride_h::Int) where T - @inbounds for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width - hstart = (ph - 1)*stride_h - pad_h - wstart = (pw - 1)*stride_w - pad_w - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - m = typemin(T) - for j in hstart:hend, i in wstart:wend - m = max(m, x[i, j, c, n]) - end - y[pw, ph, c, n] = m - end -end - -function maxpool2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}; - window::Dims{2}=(2,2), padding::Dims{2}=(0,0), - stride::Dims{2}=window) where T - Wx,Hx,Cx,Nx = size(x) - Wy,Hy,Cy,Ny = size(y) - (w1,w2) = window - (p1,p2) = padding - (s1,s2) = stride - max_pooling2d_fwd!(x,y,Wx,Hx,Cx,Nx,Wy,Hy,w1,w2,p1,p2,s1,s2) - return y -end - -function max_pooling2d_bwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4}, - grad_output::AbstractArray{T,4}, grad_input::AbstractArray{T,4}, width::Int, height::Int, - channels::Int, num::Int, pooled_width::Int, pooled_height::Int, kernel_w::Int, - kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int) where T - - grad_input .= 0 - #pragma omp parallel for - for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - maxval = y[pw, ph, c, n] - d_maxval = grad_output[pw, ph, c, n] - for h = hstart:hend, w = wstart:wend - if x[w, h, c, n] == maxval - grad_input[w, h, c, n] += d_maxval - end - end - end -end - -function maxpool2d_grad!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, y::AbstractArray{T,4}, x::AbstractArray{T,4}; - window::Dims{2}=(2,2), padding::Dims{2}=(0,0), - stride::Dims{2}=window) where T - Wx,Hx,Cx,Nx = size(x) - Wy,Hy,Cy,Ny = size(y) - (w1,w2) = window - (p1,p2) = padding - (s1,s2) = stride - max_pooling2d_bwd!(x,y,dy,dx,Wx,Hx,Cx,Nx,Wy,Hy,w1,w2,p1,p2,s1,s2) - return dx -end - - -function mean_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4}, - width::Int, height::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, kernel_w::Int, kernel_h::Int,pad_w::Int, pad_h::Int, - 
stride_w::Int, stride_h::Int) where T - kernel_size = kernel_w * kernel_h - @inbounds for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - s = zero(T) - for j in hstart:hend, i in wstart:wend - s += x[i, j, c, n] - end - y[pw, ph, c, n] = s / kernel_size - end -end - -function meanpool2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}; - window::Dims{2}=(2,2), padding::Dims{2}=(0,0), - stride::Dims{2}=window) where T - Wx,Hx,Cx,Nx = size(x) - Wy,Hy,Cy,Ny = size(y) - (w1,w2) = window - (p1,p2) = padding - (s1,s2) = stride - mean_pooling2d_fwd!(x,y,Wx,Hx,Cx,Nx,Wy,Hy,w1,w2,p1,p2,s1,s2) - return y -end - -function mean_pooling2d_bwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4}, - width::Int, height::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, - stride_w::Int, stride_h::Int) where T - - x[:, :, :, :] .= 0 - kernel_size = kernel_w * kernel_h - - #pragma omp parallel for - for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - oval = y[pw, ph, c, n] / kernel_size - x[wstart:wend, hstart:hend, c, n] .+= oval - end -end - -function meanpool2d_grad!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, y::AbstractArray{T,4}, x::AbstractArray{T,4}; - window::Dims{2}=(2,2), padding::Dims{2}=(0,0), - stride::Dims{2}=window) where T - Wx,Hx,Cx,Nx = size(x) - Wy,Hy,Cy,Ny = size(y) - (w1,w2) = window - (p1,p2) = padding - (s1,s2) = stride - mean_pooling2d_bwd!(dx,dy,Wx,Hx,Cx,Nx,Wy,Hy,w1,w2,p1,p2,s1,s2) - return dx -end - -function max_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5}, - width::Int, height::Int, depth::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, pooled_depth::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, - pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int) where T - @inbounds for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width - dstart = (pd - 1)* stride_d - pad_d - hstart = (ph - 1)* stride_h - pad_h - wstart = (pw - 1)* stride_w - pad_w - - dend = min(dstart + kernel_d, depth) - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - - dstart = max(dstart, 0) + 1 - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - m = typemin(T) - for k in dstart:dend, j in hstart:hend, i in wstart:wend - m = max(m, x[i, j, k, c, n]) - end - y[pw, ph, pd, c, n] = m - end -end - -function maxpool3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}; - window::Dims{3}=(2,2,2), padding::Dims{3}=(0,0,0), - stride::Dims{3}=window) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Wy,Hy,Dy,Cy,Ny = size(y) - (w1,w2,w3) = psize(window, x) - (p1,p2,p3) = psize(padding, x) - (s1,s2,s3) = psize(stride, x) - max_pooling3d_fwd!(x,y,Wx,Hx,Dx,Cx,Nx,Wy,Hy,Dy,w1,w2,w3,p1,p2,p3,s1,s2,s3) - return y -end - -function max_pooling3d_bwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5}, - grad_output::AbstractArray{T,5}, grad_input::AbstractArray{T,5}, width::Int, height::Int, depth::Int, - channels::Int, num::Int, pooled_width::Int, pooled_height::Int, 
pooled_depth::Int, - kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Int, pad_h::Int, pad_d::Int, - stride_w::Int, stride_h::Int, stride_d::Int) where T - - grad_input .= 0 - - #pragma omp parallel for - for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width - dstart = (pd - 1) * stride_h - pad_h - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - - dend = min(dstart + kernel_d, depth) - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - - dstart = max(dstart, 0) + 1 - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - maxval = y[pw, ph, pd, c, n] - d_maxval = grad_output[pw, ph, pd, c, n] - for d = dstart:dend, h = hstart:hend, w = wstart:wend - if x[w, h, d, c, n] == maxval - grad_input[w, h, d, c, n] += d_maxval - end - end - end -end - -function maxpool3d_grad!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, y::AbstractArray{T,5}, x::AbstractArray{T,5}; - window::Dims{3}=(2,2,2), padding::Dims{3}=(0,0,0), - stride::Dims{3}=window) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Wy,Hy,Dy,Cy,Ny = size(y) - (w1,w2,w3) = psize(window, x) - (p1,p2,p3) = psize(padding, x) - (s1,s2,s3) = psize(stride, x) - max_pooling3d_bwd!(x,y,dy,dx,Wx,Hx,Dx,Cx,Nx,Wy,Hy,Dy,w1,w2,w3,p1,p2,p3,s1,s2,s3) - return dx -end - -function mean_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5}, - width::Int, height::Int, depth::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, pooled_depth::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, - pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int) where T - - kernel_size = kernel_w * kernel_h * kernel_d - #pragma omp parallel for - @inbounds for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width - dstart = (pd - 1) * stride_d - pad_d - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - - dend = min(dstart + kernel_d, depth) - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - - dstart = max(dstart, 0) + 1 - hstart = max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - s = zero(T) - for k in dstart:dend, j in hstart:hend, i in wstart:wend - s += x[i, j, k, c, n] - end - y[pw, ph, pd, c, n] = s / kernel_size - end -end - -function meanpool3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}; - window::Dims{3}=(2,2), padding::Dims{3}=(0,0), - stride::Dims{3}=window) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Wy,Hy,Dy,Cy,Ny = size(y) - (w1,w2,w3) = psize(window, x) - (p1,p2,p3) = psize(padding, x) - (s1,s2,s3) = psize(stride, x) - mean_pooling3d_fwd!(x,y,Wx,Hx,Dx,Cx,Nx,Wy,Hy,Dy,w1,w2,w3,p1,p2,p3,s1,s2,s3) - return y -end - -function mean_pooling3d_bwd!(grad_input::AbstractArray{T,5}, grad_output::AbstractArray{T,5}, - width::Int, height::Int, depth::Int, channels::Int, num::Int, pooled_width::Int, - pooled_height::Int, pooled_depth::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, - pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int) where T - - kernel_size = kernel_w * kernel_h * kernel_d - fill!(grad_input, 0.0) - - #pragma omp parallel for - for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width - dstart = (pd - 1) * stride_d - pad_d - hstart = (ph - 1) * stride_h - pad_h - wstart = (pw - 1) * stride_w - pad_w - dend = min(dstart + kernel_d, depth) - hend = min(hstart + kernel_h, height) - wend = min(wstart + kernel_w, width) - dstart = max(dstart, 0) + 1 - hstart = 
max(hstart, 0) + 1 - wstart = max(wstart, 0) + 1 - - grad_input[wstart:wend, hstart:hend, dstart:dend, c, n] .+= grad_output[pw, ph, pd, c, n] ./ kernel_size - end -end - -function meanpool3d_grad!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, y::AbstractArray{T,5}, x::AbstractArray{T,5}; - window::Dims{3}=(2,2,2), padding::Dims{3}=(0,0,0), - stride::Dims{3}=window) where T - Wx,Hx,Dx,Cx,Nx = size(x) - Wy,Hy,Dy,Cy,Ny = size(y) - (w1,w2,w3) = psize(window, x) - (p1,p2,p3) = psize(padding, x) - (s1,s2,s3) = psize(stride, x) - mean_pooling3d_bwd!(dx,dy,Wx,Hx,Dx,Cx,Nx,Wy,Hy,Dy,w1,w2,w3,p1,p2,p3,s1,s2,s3) - return dx -end diff --git a/src/impl/pooling_direct.jl b/src/impl/pooling_direct.jl new file mode 100644 index 000000000..93125f1fd --- /dev/null +++ b/src/impl/pooling_direct.jl @@ -0,0 +1,250 @@ +using Statistics + +# Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing +# the inner loop operation and a few initialization parameters. +for name in (:max, :mean) + @eval function $((Symbol("$(name)pool_direct!")))( + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} + check_dims(size(x), size(y), pdims) + + width, height, depth = input_size(pdims) + kernel_w, kernel_h, kernel_d = kernel_size(pdims) + out_c = channels_out(pdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(pdims) + dil_w, dil_h, dil_d = dilation(pdims) + stride_w, stride_h, stride_d = stride(pdims) + out_width, out_height, out_depth = output_size(pdims) + + # We use calc_padding_regions to split outselves up into separate regions that may or + # may not need to worry about padding: + padded_regions, central_region = calc_padding_regions(pdims) + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + # If we're doing mean pooling, we represent division by kernel size by rolling it + # into the `alpha` multiplier. + if $(name == :mean) + alpha = alpha/prod(kernel_size(pdims)) + end + + # Each loop, we initialize `m` to something, set that here. + m_init = if $(name == :max) + typemin(T) + elseif $(name == :mean) + T(0) + else + error("Unimplemented codegen path") + end + + # Start with the central region + w_region, h_region, d_region = central_region + @inbounds for batch_idx in 1:size(x)[end], + c in 1:out_c, + d in d_region, + h in h_region, + w in w_region + + # Initialize `m` to `0.0`, or `typemin(T)` or whatever. + m = m_init + + for kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w + + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + + # This conditional will be optimized away at compile time + if $(name == :max) + m = max(m, x[input_kw, input_kh, input_kd, c, batch_idx]) + elseif $(name == :mean) + m += x[input_kw, input_kh, input_kd, c, batch_idx] + else + error("Unimplemented codegen path") + end + end + y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] + end + + # Next, the padded regions + @inbounds for (w_region, h_region, d_region) in padded_regions + for batch_idx in 1:size(x)[end], + c in 1:out_c, + d in d_region, + h in h_region, + w in w_region + + # In these loops, we have to check that we're not reaching off the edge, we + # do so by putting in a bunch of conditionals. 
:/ + m = m_init + for kd in 1:kernel_d + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + if input_kd <= 0 || input_kd > depth + m = max(m, 0.0) + continue + end + + for kh in 1:kernel_h + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + if input_kh <= 0 || input_kh > height + m = max(m, 0.0) + continue + end + + for kw in 1:kernel_w + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + if input_kw <= 0 || input_kw > width + m = max(m, 0.0) + continue + end + + if $(name == :max) + m = max(m, x[input_kw, input_kh, input_kd, c, batch_idx]) + elseif $(name == :mean) + m += x[input_kw, input_kh, input_kd, c, batch_idx] + else + error("Unimplemented codegen path") + end + end + end + end + y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] + end + end + + # Return `y` + return y + end + + # Same story for gradients, and although this is very similar to the forward pass, + # it's unfortunately different enough that I think we need a separate function. :( + @eval function $((Symbol("∇$(name)pool_direct!")))( + dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} + check_dims(size(x), size(dy), pdims) + + width, height, depth = input_size(pdims) + kernel_w, kernel_h, kernel_d = kernel_size(pdims) + out_c = channels_out(pdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(pdims) + dil_w, dil_h, dil_d = dilation(pdims) + stride_w, stride_h, stride_d = stride(pdims) + out_width, out_height, out_depth = output_size(pdims) + + # We use calc_padding_regions to split outselves up into separate regions that + # may or may not need to worry about padding: + padded_regions, central_region = calc_padding_regions(pdims) + + # A helper function to project from output (w, h) to input (input_w, input_h) + @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + # If we're doing mean pooling, we represent division by kernel size by rolling + # it into the `alpha` multiplier. + if $(name == :mean) + alpha = alpha/prod(kernel_size(pdims)) + end + + # Start with the central region + w_region, h_region, d_region = central_region + @inbounds for batch_idx in 1:size(x)[end], + c in 1:out_c, + d in d_region, + h in h_region, + w in w_region + + # Grab the output at this index for future use + y_idx = y[w, h, d, c, batch_idx] + dy_idx = dy[w, h, d, c, batch_idx] + maxpool_already_chose = false + + for kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w + + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + + # This conditional will be optimized away at compile time, + # or my name isn't shengdan jingyu + x_idxs = (input_kw, input_kh, input_kd, c, batch_idx) + if $(name == :max) + # If it's equal; this is the one we chose. We only choose one per + # kernel window, all other elements of dx must be zero. + if y_idx == x[x_idxs...] && !maxpool_already_chose + dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] + maxpool_already_chose = true + # Maxpooling does not support `beta` right now. :( + #else + # dx[x_idxs...] = T(0) + beta*dx[x_idxs...] + end + elseif $(name == :mean) + # Either does meanpool :( + dx[x_idxs...] = dy_idx*alpha + dx[x_idxs...] 
+ else + error("Unimplemented codegen path") + end + end + end + + # Next, the padded regions + @inbounds for (w_region, h_region, d_region) in padded_regions + for batch_idx in 1:size(x)[end], + c in 1:out_c, + d in d_region, + h in h_region, + w in w_region + + # Grab the incoming gradient at this index for future use + y_idx = dy[w, h, d, c, batch_idx] + dy_idx = dy[w, h, d, c, batch_idx] + maxpool_already_chose = false + + # In these loops, we have to check that we're not reaching off the edge, + # we do so by putting in a bunch of conditionals. :/ + for kd in 1:kernel_d + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d + if input_kd <= 0 || input_kd > depth + continue + end + + for kh in 1:kernel_h + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + if input_kh <= 0 || input_kh > height + continue + end + + for kw in 1:kernel_w + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + if input_kw <= 0 || input_kw > width + continue + end + + # Same as above + x_idxs = (input_kw, input_kh, input_kd, c, batch_idx) + if $(name == :max) + if y_idx == x[x_idxs...] && !maxpool_already_chose + dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] + maxpool_already_chose = true + #else + # dx[x_idxs...] = T(0) + beta*dx[x_idxs...] + end + elseif $(name == :mean) + dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...] + else + error("Unimplemented codegen path") + end + end + end + end + end + end + + # Return `dx` + return dx + end +end diff --git a/src/linalg.jl b/src/linalg.jl deleted file mode 100644 index b79e34d6f..000000000 --- a/src/linalg.jl +++ /dev/null @@ -1,30 +0,0 @@ -## Low level gemm! call with pointers -## Borrowed from Knet.jl - -using LinearAlgebra -using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc - -# C := alpha*op(A)*op(B) + beta*C, where: -# op(X) is one of op(X) = X, or op(X) = XT, or op(X) = XH, -# alpha and beta are scalars, -# A, B and C are matrices: -# op(A) is an m-by-k matrix, -# op(B) is a k-by-n matrix, -# C is an m-by-n matrix. 
- -for (gemm, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32)) - @eval begin - function gemm!(transA::Char, transB::Char, M::Int, N::Int, K::Int, alpha::($elty), A::Ptr{$elty}, B::Ptr{$elty}, beta::($elty), C::Ptr{$elty}) - if transA=='N'; lda=M; else; lda=K; end - if transB=='N'; ldb=K; else; ldb=N; end - ldc = M; - ccall((@blasfunc($(gemm)), libblas), Nothing, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, - Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, - Ref{BlasInt}), - transA, transB, M, N, K, - alpha, A, lda, B, ldb, beta, C, ldc) - end - end -end diff --git a/src/logsoftmax.jl b/src/logsoftmax.jl deleted file mode 100644 index 1611cab9d..000000000 --- a/src/logsoftmax.jl +++ /dev/null @@ -1,34 +0,0 @@ -using Base.Threads - -function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat) - @threads for j = 1:size(xs, 2) - @inbounds begin - xi_max = xs[1, j] - for i = 1:size(out, 1) - xi_max = max(xi_max, xs[i, j]) - end - s = zero(eltype(out)) - for i = 1:size(out, 1) - s += exp(xs[i, j] - xi_max) - end - for i = 1:size(out, 1) - out[i, j] = xs[i, j] - log(s) - xi_max - end - end - end - out -end - - -logsoftmax!(xs) = logsoftmax!(xs, xs) -logsoftmax(xs) = logsoftmax!(similar(xs), xs) - -∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ max.(eps(eltype(xs)),softmax(xs)), xs) - -""" - logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs))) - -logsoftmax computes the log of softmax(xs) and it is more numerically stable -than softmax function in computing the cross entropy loss. -""" -logsoftmax diff --git a/src/numeric.jl b/src/numeric.jl deleted file mode 100644 index 92a5fd28f..000000000 --- a/src/numeric.jl +++ /dev/null @@ -1,89 +0,0 @@ -using Base.Math: @horner, significand_bits, exponent_raw_max, exponent_bias - -if VERSION < v"0.7.0-DEV.1430" - using Base.Math: fpinttype -else - using Base: uinttype -end - -# log_fast from -# https://github.com/musm/SLEEF.jl/blob/c9dcd2eb090d69ec40790f19798c5fef2aba2616/src/log.jl - -const MLN2 = 6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2) - -@inline float2integer(d::Float64) = (reinterpret(Int64, d) >> significand_bits(Float64)) % Int -@inline float2integer(d::Float32) = (reinterpret(Int32, d) >> significand_bits(Float32)) % Int - -@inline function ilogb2k(d::T) where {T<:Union{Float32,Float64}} - (float2integer(d) & exponent_raw_max(T)) - exponent_bias(T) -end - -@inline function ldexp3k(x::T, e::Int) where {T<:Union{Float32,Float64}} - @static if VERSION < v"0.7.0-DEV.1430" - reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % fpinttype(T)) - else - reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % uinttype(T)) - end -end - -""" - log_fast(x) -Compute the natural logarithm of `x`. 
The inverse of the natural logarithm is -the natural expoenential function `exp(x)` -""" -function log_fast end - -let -global log_fast - -c8d = 0.153487338491425068243146 -c7d = 0.152519917006351951593857 -c6d = 0.181863266251982985677316 -c5d = 0.222221366518767365905163 -c4d = 0.285714294746548025383248 -c3d = 0.399999999950799600689777 -c2d = 0.6666666666667778740063 -c1d = 2.0 - -c5f = 0.2392828464508056640625f0 -c4f = 0.28518211841583251953125f0 -c3f = 0.400005877017974853515625f0 -c2f = 0.666666686534881591796875f0 -c1f = 2f0 - -global @inline log_fast_kernel(x::Float64) = @horner x c1d c2d c3d c4d c5d c6d c7d c8d -global @inline log_fast_kernel(x::Float32) = @horner x c1f c2f c3f c4f c5f - -function log_fast(d::T) where {T<:Union{Float32,Float64}} - o = d < realmin(T) - o && (d *= T(Int64(1) << 32) * T(Int64(1) << 32)) - - e = ilogb2k(d * T(1.0/0.75)) - m = ldexp3k(d, -e) - o && (e -= 64) - - x = (m - 1) / (m + 1) - x2 = x * x - - t = log_fast_kernel(x2) - - x = x * t + T(MLN2) * e - - isinf(d) && (x = T(Inf)) - (d < 0 || isnan(d)) && (x = T(NaN)) - d == 0 && (x = -T(Inf)) - - return x -end -end - -log_fast(x::Union{Int32,Int64}) = log_fast(float(x)) - -# Derivatives - -@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - function log_fast(d::ForwardDiff.Dual{T,<:Union{Float32,Float64}}) where T - x = ForwardDiff.value(d) - Dual{T}(log_fast(x), inv(x) * ForwardDiff.partials(d)) - end -end diff --git a/src/pooling.jl b/src/pooling.jl new file mode 100644 index 000000000..e349e82e4 --- /dev/null +++ b/src/pooling.jl @@ -0,0 +1,127 @@ +export maxpool, maxpool!, meanpool, meanpool!, ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool! + +## Pooling API +# +# We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d, +# 2d and 3d pooling, based on the rank of the input tensors, in both mutating and +# non-mutating auto-allocating variants: +# - Pooling: +# - maxpool(x, pdims) +# - maxpool!(y, x, pdims) +# - meanpool(x, pdims) +# - meanpool!(y, x, pdims) +# - Pooling input backprop +# - ∇maxpool(dy, pdims) +# - ∇maxpool!(dx, dy, pdims) +# - ∇meanpool(dy, pdims) +# - ∇meanpool!(dx, dy, pdims) +# +# All methods require a `PoolDims` object to define the dimensions and optional +# elements of the convolution (stride, dilation, etc...), which is easily constructable +# through something like `PoolDims(x, w)`. + + +# First, we will define mappings from the generic API names to our accelerated backend +# implementations. At the moment this is only the direct implementation, however this +# exists here so that other packages (NNPACK, MAGMA, etc...) can override this easily. +for (front_name, backend) in ( + # This maps from public, front-facing name, to internal backend name + :maxpool => :direct, + :meanpool => :direct, + ) + + # We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling + @eval begin + function $(Symbol("$(front_name)!"))( + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + pdims::PoolDims; kwargs...) where {T} + $(Symbol("$(front_name)_$(backend)!"))(y, x, pdims; kwargs...) + end + end +end + +# Do the same for backprops +for (front_name, backend) in ( + :∇maxpool => :direct, + :∇meanpool => :direct, + ) + @eval begin + function $(Symbol("$(front_name)!"))( + dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, + y::AbstractArray{T,5}, x::AbstractArray{T,5}, + pdims::PoolDims; kwargs...) where {T} + $(Symbol("$(front_name)_$(backend)!"))(dx, dy, y, x, pdims; kwargs...) 
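+            # (For illustration only: with `front_name = :∇maxpool` and `backend = :direct`,
+            # the generated method is simply
+            #   ∇maxpool!(dx, dy, y, x, pdims; kwargs...) = ∇maxpool_direct!(dx, dy, y, x, pdims; kwargs...)
+            # and similarly for `∇meanpool`.)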
+ end + end +end + + +# Our strategy for pooling is to reshape to an array with three spatial dimensions, which +# makes things MUCH EASIER for us on the backend side, and is in general pretty fast, +# since we can specialize on sizes. +for front_name in (:maxpool, :meanpool) + for backend in (Symbol(), :_direct) + for N in (3, 4) + @eval begin + function $(Symbol("$(front_name)$(backend)!"))( + y::AbstractArray{T,$N}, x::AbstractArray{T,$N}, + pdims::PoolDims; kwargs...) where {T} + $(Symbol("$(front_name)$(backend)!"))( + insert_singleton_spatial_dimension(y, $(5 - N)), + insert_singleton_spatial_dimension(x, $(5 - N)), + insert_singleton_spatial_dimension(pdims, $(5 - N)); + kwargs... + ) + + # We explicitly return `y` here, because the backend call + # itself may return a reshaped view, which we don't want. + return y + end + + # backprops too + function $(Symbol("∇$(front_name)$(backend)!"))( + dx::AbstractArray{T,$N}, dy::AbstractArray{T,$N}, + y::AbstractArray{T,$N}, x::AbstractArray{T,$N}, + pdims::PoolDims; kwargs...) where {T} + $(Symbol("∇$(front_name)$(backend)!"))( + insert_singleton_spatial_dimension(dx, $(5 - N)), + insert_singleton_spatial_dimension(dy, $(5 - N)), + insert_singleton_spatial_dimension(y, $(5 - N)), + insert_singleton_spatial_dimension(x, $(5 - N)), + insert_singleton_spatial_dimension(pdims, $(5 - N)); + kwargs... + ) + + # We explicitly return `dx` here, because the backend call + # itself may return a reshaped view, which we don't want. + return dx + end + end + end + end +end + + +# Finally, let's generate auto-allocating versions of all our functions, for all backends: +for backend in (Symbol(), :_direct, :_im2col) + # First make auto-allocating versions of the basic pooling calls: + for name in (:maxpool, :meanpool) + @eval begin + @timeit_debug to function $(Symbol("$(name)$(backend)"))( + x::AbstractArray{xT,N}, + pdims::PoolDims; kwargs...) where {xT, N} + y = zeros(xT, output_size(pdims)..., channels_out(pdims), size(x, N)) + return $(Symbol("$(name)$(backend)!"))(y, x, pdims; kwargs...) + end + + # Backprops too + @timeit_debug to function $(Symbol("∇$(name)$(backend)"))( + dy::AbstractArray{T,N}, y::AbstractArray{T,N}, + x::AbstractArray{T,N}, pdims::PoolDims; + kwargs...) where {T, N} + dx = zeros(T, input_size(pdims)..., channels_in(pdims), size(dy, N)) + return $(Symbol("∇$(name)$(backend)!"))(dx, dy, y, x, pdims; kwargs...) 
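+            # Putting it all together, a typical call sequence looks roughly like this
+            # (sizes and keyword choices are illustrative, not prescriptive):
+            #
+            #   x     = rand(Float32, 28, 28, 3, 8)        # (W, H, C, N)
+            #   pdims = PoolDims(x, (2, 2); stride=2)
+            #   y     = maxpool(x, pdims)                   # 14×14×3×8
+            #   dx    = ∇maxpool(ones(Float32, size(y)), y, x, pdims)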
+ end + end + end +end diff --git a/src/softmax.jl b/src/softmax.jl index 06acfa010..f58ae8f3b 100644 --- a/src/softmax.jl +++ b/src/softmax.jl @@ -1,40 +1,5 @@ -using Base.Threads - -function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where T<:AbstractFloat - @threads for j = 1:size(xs, 2) - @inbounds begin - # out[end, :] .= maximum(xs, 1) - out[end, j] = xs[end, j] - for i = 1:size(xs, 1) - out[end, j] = max(out[end, j], xs[i, j]) - end - # out .= exp(xs .- out[end, :]) - for i = 1:size(out, 1) - out[i, j] = exp(xs[i, j] - out[end, j]) - end - # out ./= sum(out, 1) - s = zero(eltype(out)) - for i = 1:size(out, 1) - s += out[i, j] - end - for i = 1:size(out, 1) - out[i, j] /= s - end - end - end - return out -end - -softmax!(xs) = softmax!(xs, xs) -softmax(xs) = softmax!(similar(xs), xs) - -function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat) - sf = softmax(xs) - out .= sf .* (Δ .- sum(Δ .*sf, dims = 1)) -end - -∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs) -∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs) +export softmax, softmax!, ∇softmax, ∇softmax!, + logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax! """ softmax(xs) = exp.(xs) ./ sum(exp.(xs)) @@ -52,4 +17,69 @@ independent. 0.244728 0.665241 """ -softmax +softmax(xs) = softmax!(similar(xs), xs) + +function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T} + @inbounds for j = 1:size(xs, 2) + # First, store column-wise maximum in the last element of `out` + out[end, j] = xs[end, j] + @inbounds for i = 1:(size(xs, 1) - 1) + out[end, j] = max(out[end, j], xs[i, j]) + end + + # Subtract the column-wise maximums to normalize, take exp() + # out .= exp(xs .- out[end, :]) + @inbounds for i = 1:size(out, 1) + out[i, j] = exp(xs[i, j] - out[end, j]) + end + + # Normalize by sum of the entire thing + # out ./= sum(out, 1) + s = T(0) + @inbounds for i = 1:size(out, 1) + s += out[i, j] + end + @inbounds for i = 1:size(out, 1) + out[i, j] /= s + end + end + return out +end + +function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat) + sf = softmax(xs) + out .= sf .* (Δ .- sum(Δ .*sf, dims = 1)) +end + +∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs) +∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs) + + +""" + logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs))) + +`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable +way than directly taking the log of the the softmax function, which is commonly used in +computing cross entropy loss. 
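+
+For example, `logsoftmax(Float32[1, 2, 3000.]) ≈ [-2999, -2998, 0]`, whereas naively
+computing `log.(softmax(Float32[1, 2, 3000.]))` yields `-Inf` in the first two entries,
+because the corresponding softmax values underflow to zero.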
+""" +logsoftmax(xs) = logsoftmax!(similar(xs), xs) +function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat) + for j = 1:size(xs, 2) + @inbounds begin + xi_max = xs[1, j] + for i = 1:size(out, 1) + xi_max = max(xi_max, xs[i, j]) + end + s = zero(eltype(out)) + for i = 1:size(out, 1) + s += exp(xs[i, j] - xi_max) + end + for i = 1:size(out, 1) + out[i, j] = xs[i, j] - log(s) - xi_max + end + end + end + return out +end +∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ max.(eps(eltype(xs)),softmax(xs)), xs) +∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs) \ No newline at end of file diff --git a/test/activation.jl b/test/activation.jl index 3a9d3221d..450e0d8dc 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -1,3 +1,5 @@ +using NNlib, Test + ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign]; function test_value_float_precision_preserving(a) @@ -22,83 +24,105 @@ function test_value_int_input_forces_float64(a) end end -# if Base.find_in_path("ForwardDiff") ≠ nothing -# using ForwardDiff -# function test_value_duals(a) -# @testset "$(a): " begin -# for T in [Float32, Float64, Int32, Int64] -# val = @inferred a(ForwardDiff.Dual(float(T(1)), one(float(T)))) -# @test typeof(val) == ForwardDiff.Dual{Nothing,float(T),1} -# end -# end -# end -# -# test_value_duals.(ACTIVATION_FUNCTIONS) -# end - @testset "Activation Functions" begin - - @test σ(0.0) == 0.5 - @test relu(0.0) == 0.0 - @test leakyrelu(0.0) == 0.0 - @test elu(0.0) == 0.0 - @test gelu(0.0) == 0.0 - @test swish(0.0) == 0.0 - @test softplus(0.0) ≈ log(2.0) - @test softplus(1e8) ≈ 1e8 - @test softplus(-1e8) ≈ 0.0 - @test softsign(0.0) == 0.0 - @test selu(0.0) == 0.0 - - @test σ(1.0) == 1.0 / (1.0 + exp(-1.0)) - @test relu(1.0) == 1.0 - @test leakyrelu(1.0) == 1.0 - @test elu(1.0) == 1.0 - @test gelu(1.0) == 0.8411919906082768 - @test swish(1.0) == 1.0 / (1.0 + exp(-1.0)) - @test softplus(1.0) ≈ log(exp(1.0) + 1.0) - @test softsign(1.0) == 0.5 - @test selu(1.0) == 1.0507009873554804934193349852946 - - @test σ(-1.0) == 1.0 / (1.0 + exp(1.0)) - @test relu(-1.0) == 0.0 - @test leakyrelu(-1.0) == -0.01 - @test elu(-1.0) == exp(-1.0) - 1.0 - @test gelu(-1.0) == -0.15880800939172324 - @test swish(-1.0) == -1.0 / (1.0 + exp(1.0)) - @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0) - @test softsign(-1.0) == -0.5 - @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0) - - @testset "Float inference" begin - test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS) - end - - @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin - test_value_int_input_forces_float64.(filter(x -> x != relu, ACTIVATION_FUNCTIONS)) - - @testset "relu: " begin - # relu doesn't have to force floating point outputs - @test typeof(relu(Int64(1))) == Int64 - @test typeof(relu(Int32(1))) == Int32 + @test σ(0.0) == 0.5 + @test relu(0.0) == 0.0 + @test leakyrelu(0.0) == 0.0 + @test elu(0.0) == 0.0 + @test gelu(0.0) == 0.0 + @test swish(0.0) == 0.0 + @test softplus(0.0) ≈ log(2.0) + @test softplus(1e8) ≈ 1e8 + @test softplus(-1e8) ≈ 0.0 + @test softsign(0.0) == 0.0 + @test selu(0.0) == 0.0 + + @test σ(1.0) == 1.0 / (1.0 + exp(-1.0)) + @test relu(1.0) == 1.0 + @test leakyrelu(1.0) == 1.0 + @test elu(1.0) == 1.0 + @test gelu(1.0) == 0.8411919906082768 + @test swish(1.0) == 1.0 / (1.0 + exp(-1.0)) + @test softplus(1.0) ≈ log(exp(1.0) + 1.0) + @test softsign(1.0) == 0.5 + @test selu(1.0) == 1.0507009873554804934193349852946 + + @test σ(-1.0) == 1.0 / (1.0 + exp(1.0)) + 
@test relu(-1.0) == 0.0 + @test leakyrelu(-1.0) == -0.01 + @test elu(-1.0) == exp(-1.0) - 1.0 + @test gelu(-1.0) == -0.15880800939172324 + @test swish(-1.0) == -1.0 / (1.0 + exp(1.0)) + @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0) + @test softsign(-1.0) == -0.5 + @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0) + + @testset "Float inference" begin + test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS) end - end + @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin + test_value_int_input_forces_float64.(filter(x -> x != relu, ACTIVATION_FUNCTIONS)) - xs = rand(5,5) - - @test all(sum(softmax(xs), dims = 1) .≈ 1) + @testset "relu: " begin + # relu doesn't have to force floating point outputs + @test typeof(relu(Int64(1))) == Int64 + @test typeof(relu(Int32(1))) == Int32 + end + end - @test sum(softmax(vec(xs))) ≈ 1 + @testset "softmax" begin + xs = rand(5,5) + @test all(sum(softmax(xs), dims = 1) .≈ 1) + @test sum(softmax(vec(xs))) ≈ 1 + + xs = [-100_000, -100_000.] + @test softmax(xs) ≈ [0.5, 0.5] + @test logsoftmax(xs) ≈ log.([0.5, 0.5]) + + xs = rand(5) + @test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs)) + @test logsoftmax(xs) ≈ log.(softmax(xs)) + + xs = Float32[1, 2, 3000.] + @test logsoftmax(xs) ≈ [-2999, -2998, 0] + + xs = Float32[1 2 3; 1000 2000 3000] + @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.] + + @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) + @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) + + # These values precalculated using PyTorch's nn.LogSoftmax + xs = [ + -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842; + 0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663; + -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977 + ] + ys = [ + 0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273; + -0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428; + 0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697 + ] + @test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6) + @test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6) + end - @testset "elu" begin - @test elu(42) == 42 - @test elu(42.) == 42. + @testset "elu" begin + @test elu(42) == 42 + @test elu(42.) == 42. - @test elu(-4) ≈ (exp(-4) - 1) - end + @test elu(-4) ≈ (exp(-4) - 1) + end - @test leakyrelu( 0.4,0.3) ≈ 0.4 - @test leakyrelu(-0.4,0.3) ≈ -0.12 + @test leakyrelu( 0.4,0.3) ≈ 0.4 + @test leakyrelu(-0.4,0.3) ≈ -0.12 + @testset "logsigmoid" begin + xs = randn(10,10) + @test logsigmoid.(xs) ≈ log.(sigmoid.(xs)) + for T in [:Float32, :Float64] + @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.] + end + end end diff --git a/test/conv.jl b/test/conv.jl index e637ee2b7..6a3c593b3 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -1,311 +1,641 @@ -using NNlib: conv, crosscor, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv, ∇depthwiseconv_filter, ∇depthwiseconv_data - -@testset "conv2d" begin - x = reshape(Float64[1:20;], 5, 4, 1, 1) - w = reshape(Float64[1:4;], 2, 2, 1, 1) - w1 = reshape(Float64[1:6;], 2, 3, 1, 1) - w2 = reshape(Float64[1:6;], 3, 2, 1, 1) - - @test dropdims(conv(x, w1), dims = (3,4)) == [ - 95.0 200.0; - 116.0 221.0; - 137.0 242.0; - 158.0 263.0] - - @test dropdims(conv(x, w2), dims = (3,4)) == [ - 68.0 173.0 278.0; - 89.0 194.0 299.0; - 110.0 215.0 320.0] - - @test dropdims(conv(x, w), dims = (3,4)) == [ - 29 79 129; - 39 89 139; - 49 99 149; - 59 109 159.] 
- - @test dropdims(conv(view(x, :, :, :, :), w), dims = (3,4)) == [ - 29 79 129; - 39 89 139; - 49 99 149; - 59 109 159.] - - @test dropdims(crosscor(x, w), dims = (3,4)) == [ - 51 101 151; - 61 111 161; - 71 121 171; - 81 131 181.] - - @test dropdims(conv(Float32.(x), Float32.(w)), dims=(3,4)) == Float32.([ - 29 79 129; - 39 89 139; - 49 99 149; - 59 109 159.]) - - @test dropdims(conv(x, w; stride=2), dims = (3,4)) == [ - 29 129; - 49 149.] - - @test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4)) == Float32.([ - 29 129; - 49 149.]) - - @test dropdims(conv(x, w; pad=1), dims = (3,4)) == [ - 1.0 9.0 29.0 49.0 48.0; - 4.0 29.0 79.0 129.0 115.0; - 7.0 39.0 89.0 139.0 122.0; - 10.0 49.0 99.0 149.0 129.0; - 13.0 59.0 109.0 159.0 136.0; - 10.0 40.0 70.0 100.0 80.0 - ] - - @test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (3,4)) == Float32.([ - 1.0 9.0 29.0 49.0 48.0; - 4.0 29.0 79.0 129.0 115.0; - 7.0 39.0 89.0 139.0 122.0; - 10.0 49.0 99.0 149.0 129.0; - 13.0 59.0 109.0 159.0 136.0; - 10.0 40.0 70.0 100.0 80.0 - ]) - - @test dropdims(conv(x, w; dilation=2), dims = (3,4)) == [ - 48 98; - 58 108; - 68 118.] - - # NaN tests for dilation forward pass - - ys = [] - for idx in 1:1000 - push!(ys, conv(x, w; dilation=2)) - end - @test !any([any(isnan.(ys[idx])) for idx in 1:1000]) - - # for gradients, check only size - # correctness of gradients is cross-checked with CUDNN.jl - # (it's assumed convolution code won't change often) - - @test size(∇conv_filter(reshape(rand(4,3), 4, 3, 1, 1), x)) == size(w) - @test size(∇conv_data(reshape(rand(4,3), 4, 3, 1, 1), w)) == size(x) - - # Test that stride/pad work backward as well - y = conv(x, w; stride=2, pad=1, dilation=2) - @test size(y) == (3, 2, 1, 1) - @test size(∇conv_filter(y, x; size=size(w), stride=2, pad=1, dilation=2)) == size(w) - @test size(∇conv_data(y, w; size=size(x), stride=2, pad=1, dilation=2)) == size(x) - - # NaN tests for dilation backward pass: filters - dy = randn(size(ys[1])) - dws = [] - for idx in 1:1000 - push!(dws, ∇conv_filter(dy, x; size=size(w), dilation=2)) - end - - # NaN tests for dilation backward pass: input - dxs = [] - for idx in 1:1000 - push!(dxs, ∇conv_data(dy, w; size=size(x), dilation=2)) - end - - @test !any([any(isnan.(dws[idx])) for idx in 1:1000]) - @test !any([any(isnan.(dxs[idx])) for idx in 1:1000]) - +using NNlib, Test +using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier, + stride, padding, dilation, flipkernel, output_size + +@testset "ConvDims" begin + for T in (DenseConvDims, DepthwiseConvDims) + @testset "$(T)" begin + x = randn(5,4,3,2) + + if T == DenseConvDims + w = randn(1,2,3,4) + elseif T == DepthwiseConvDims + w = randn(1,2,4,3) + end + + # First, getters: + cdims = T(x, w) + @test input_size(cdims) == size(x)[1:2] + @test kernel_size(cdims) == size(w)[1:2] + @test channels_in(cdims) == size(x, 3) + @test stride(cdims) == (1,1) + @test dilation(cdims) == (1,1) + @test padding(cdims) == (0,0,0,0) + @test flipkernel(cdims) == false + @test output_size(cdims) == (5,3) + + # Special-case channel output tests + if T == DenseConvDims + @test channels_out(cdims) == size(w, 4) + elseif T == DepthwiseConvDims + @test channel_multiplier(cdims) == size(w, 3) + @test channels_out(cdims) == size(w,3)*size(w,4) + end + + # Next, scalar settings: + cdims = T(x, w; stride=2, dilation=2, padding=3, flipkernel=true) + @test stride(cdims) == (2,2) + @test dilation(cdims) == (2,2) + @test padding(cdims) == (3,3,3,3) + @test flipkernel(cdims) == true + @test 
output_size(cdims) == (6,4) + + # Next, tuple settings + cdims = T(x, w; stride=(1, 2), dilation=(1, 2), padding=(0,1)) + @test stride(cdims) == (1,2) + @test dilation(cdims) == (1,2) + @test padding(cdims) == (0,0,1,1) + @test output_size(cdims) == (5,2) + + # Special case for 4-d padding spec: + cdims = T(x, w; padding=(1,2,3,4)) + @test padding(cdims) == (1,2,3,4) + @test output_size(cdims) == (8,10) + + # Make sure we throw on invalid settings: + # Invalid dimensionality of settings: + @test_throws DimensionMismatch T(x, w; stride=(1,)) + @test_throws DimensionMismatch T(x, w; stride=(1, 1, 1)) + @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1)) + @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1, 1, 1)) + @test_throws DimensionMismatch T(x, w; dilation=(1,)) + @test_throws DimensionMismatch T(x, w; dilation=(1, 1, 1)) + # Dilation will cause us to reach beyond the end of input + padding here: + @test_throws DimensionMismatch T(x, w; dilation=(1, 5)) + # Channel mismatch: + if T == DenseConvDims + @test_throws DimensionMismatch T(x, w[:,:,1:1,:]) + elseif T == DepthwiseConvDims + @test_throws DimensionMismatch T(x, w[:,:,:,1:1]) + end + end + end end -@testset "depthwiseconv2d" begin - x = reshape(Float64[1:18;], 3, 3, 2, 1) - w = reshape(Float64[1:16;], 2, 2, 2, 2) - - @test depthwiseconv(x, w)[:] == [23.0, 33.0, 53.0, 63.0, 71.0, 97.0, 149.0, 175.0, 497.0, 539.0, 623.0, 665.0, 689.0, 747.0, 863.0, 921.0] - - @test depthwiseconv(x, w, stride = 2, pad = 1)[:] == [1.0, 7.0, 19.0, 63.0, 5.0, 27.0, 63.0, 175.0, 90.0, 218.0, 287.0, 665.0, 130.0, 310.0, 403.0, 921.0] - - @test depthwiseconv(x, w, stride = 2)[:] == [23.0, 71.0, 497.0, 689.0] - - @test depthwiseconv(x, w, pad = 1)[:] == [1.0, 4.0, 7.0, 6.0, 7.0, 23.0, 33.0, 24.0, 19.0, 53.0, 63.0, 42.0, 21.0, 52.0, 59.0, 36.0, 5.0, 16.0, 27.0, 18.0, 27.0, 71.0, 97.0, 60.0, 63.0, 149.0, 175.0, 102.0, 49.0, 112.0, 127.0, 72.0, 90.0, 199.0, 218.0, 120.0, 227.0, 497.0, 539.0, 294.0, 287.0, 623.0, 665.0, 360.0, 176.0, 379.0, 402.0, 216.0, 130.0, 283.0, 310.0, 168.0, 319.0, 689.0, 747.0, 402.0, 403.0, 863.0, 921.0, 492.0, 240.0, 511.0, 542.0, 288.0] - - # the correctness of the gradients are being verified by calling - # the corresponding counvolution gradients - - dy = reshape(Float64[1:16;], 2,2,4,1) - local z = ∇depthwiseconv_data(dy,x,w) - for i in 1:2 - X = copy(x[:,:,i:i,:]); - W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3])); - DY = copy(dy[:,:,2i-1:2i,:]); - res = ∇conv_data(DY,W;size=size(X)) - @test dropdims(z[:,:,i:i,:], dims=(3,4)) == dropdims(res, dims=(3,4)) +conv_answer_dict = Dict( + # Known-good answers for 1d convolution operations + 1 => Dict( + "y_pad" => [1, 4, 7, 10, 13, 10.], + "y_dil" => [5, 8, 11.], + "y_flip" => [5, 8, 11, 14.], + + "dx" => [ 8, 18, 27, 36, 13.], + "dx_stride" => [ 8, 4, 20, 10, 0.], + "dx_pad" => [ 9, 18, 27, 36, 33.], + "dx_dil" => [10, 16, 27, 8, 11.], + "dx_flip" => [ 5, 18, 27, 36, 28.], + + "dw" => [134, 100.], + "dw_stride" => [ 48, 34.], + "dw_pad" => [135, 150.], + "dw_dil" => [102, 54.], + "dw_flip" => [110, 148.], + ), + + # Known-good answers for 2d convolution operations + 2 => Dict( + "y_pad" => [ + 1 9 29 49 48; + 4 29 79 129 115; + 7 39 89 139 122; + 10 49 99 149 129; + 13 59 109 159 136; + 10 40 70 100 80. + ], + "y_dil" => [ + 48 98; + 58 108; + 68 118. + ], + "y_flip" => [ + 51 101 151; + 61 111 161; + 71 121 171; + 81 131 181. + ], + + "dx" => [ + 116 374 674 258; + 243 700 1200 407; + 313 800 1300 437; + 383 900 1400 467; + 177 386 586 159. 
+ ], + "dx_stride" => [ + 116 58 516 258; + 87 29 387 129; + 196 98 596 298; + 147 49 447 149; + 0 0 0 0. + ], + "dx_pad" => [ + 152 470 850 911; + 261 700 1200 1240; + 340 800 1300 1319; + 419 900 1400 1398; + 370 746 1126 1087. + ], + "dx_dil" => [ + 192 392 96 196; + 232 432 116 216; + 416 766 184 334; + 174 324 58 108; + 204 354 68 118. + ], + "dx_flip" => [ + 51 254 454 453; + 163 700 1200 1087; + 193 800 1300 1157; + 223 900 1400 1227; + 162 586 886 724. + ], + + "dw" => [ + 17378 11738; + 16250 10610. + ], + "dw_stride" => [ + 5668 3888; + 5312 3532. + ], + "dw_pad" => [ + 18670 22550; + 19850 23430. + ], + "dw_dil" => [ + 8632 3652; + 7636 2656. + ], + "dw_flip" => [ + 12590 19550; + 13982 20942. + ], + ), + + # Known-good answers for 3d convolution operations (these are getting rather large) + 3 => Dict( + "y_pad" => reshape([ + 1, 4, 7, 10, 13, 10, 9, 29, 39, 49, 59, 40, 29, 79, 89, 99, 109, 70, 49, 129, + 139, 149, 159, 100, 48, 115, 122, 129, 136, 80, 26, 80, 94, 108, 122, 80, 126, + 322, 358, 394, 430, 260, 206, 502, 538, 574, 610, 360, 286, 682, 718, 754, 790, + 460, 220, 502, 524, 546, 568, 320, 146, 360, 374, 388, 402, 240, 446, 1042, 1078, + 1114, 1150, 660, 526, 1222, 1258, 1294, 1330, 760, 606, 1402, 1438, 1474, 1510, + 860, 420, 942, 964, 986, 1008, 560, 205, 456, 467, 478, 489, 270, 517, 1133, 1159, + 1185, 1211, 660, 577, 1263, 1289, 1315, 1341, 730, 637, 1393, 1419, 1445, 1471, + 800, 392, 847, 862, 877, 892, 480. + ], (6,5,4)), + "y_dil" => reshape([608, 644, 680, 788, 824, 860.], (3,2,1)), + "y_flip" => reshape([ + 686, 722, 758, 794, 866, 902, 938, 974, 1046, 1082, 1118, 1154, 1406, 1442, + 1478, 1514, 1586, 1622, 1658, 1694, 1766, 1802, 1838, 1874. + ], (4,3,2)), + + "dx" => reshape([ + 2576, 5118, 5658, 6198, 3010, 5948, 11576, 12512, 13448, 6420, 8468, 16256, + 17192, 18128, 8580, 4092, 7718, 8114, 8510, 3950, 9624, 18316, 19108, 19900, + 9340, 18680, 34992, 36288, 37584, 17320, 22280, 41472, 42768, 44064, 20200, + 9776, 17756, 18260, 18764, 8340, 4168, 7438, 7690, 7942, 3450, 6972, 11896, + 12256, 12616, 5140, 8052, 13696, 14056, 14416, 5860, 2804, 4278, 4386, 4494, + 1510. + ], (5,4,3)), + "dx_stride" => reshape([ + 2576, 2254, 3152, 2758, 0, 1932, 1610, 2364, 1970, 0, 5456, 4774, 6032, + 5278, 0, 4092, 3410, 4524, 3770, 0, 1288, 966, 1576, 1182, 0, 644, 322, + 788, 394, 0, 2728, 2046, 3016, 2262, 0, 1364, 682, 1508, 754, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0. + ], (5,4,3)), + "dx_pad" => reshape([ + 4220, 6343, 7116, 7889, 6550, 8490, 12276, 13312, 14348, 11606, 12350, + 17456, 18492, 19528, 15546, 11989, 16664, 17469, 18274, 14333, 16200, + 22628, 23616, 24604, 19392, 25336, 34992, 36288, 37584, 29320, 30216, + 41472, 42768, 44064, 34200, 26236, 35664, 36652, 37640, 28940, 22816, + 30831, 31636, 32441, 24794, 32522, 43668, 44704, 45740, 34742, 36462, + 48848, 49884, 50920, 38602, 29501, 39264, 40037, 40810, 30733. + ], (5,4,3)), + "dx_dil" => reshape([ + 4864, 5152, 9696, 4508, 4760, 6304, 6592, 12396, 5768, 6020, 3648, + 3864, 7120, 3220, 3400, 4728, 4944, 9100, 4120, 4300, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040, + 3152, 3296, 5804, 2472, 2580, 1216, 1288, 1968, 644, 680, 1576, 1648, + 2508, 824, 860. 
+ ], (5,4,3)), + "dx_flip" => reshape([ + 686, 2094, 2202, 2310, 1588, 2924, 7544, 7904, 8264, 5124, 3644, 9344, + 9704, 10064, 6204, 3138, 7430, 7682, 7934, 4616, 4836, 11980, 12484, + 12988, 7792, 14936, 34992, 36288, 37584, 21640, 17816, 41472, 42768, + 44064, 25240, 12620, 28412, 29204, 29996, 16728, 7030, 15646, 16042, + 16438, 9084, 17772, 38968, 39904, 40840, 22276, 19932, 43648, 44584, + 45520, 24796, 12362, 26742, 27282, 27822, 14992. + ], (5,4,3)), + + "dw" => reshape([1.058184e6, 1.0362e6, 948264, 926280, + 618504, 596520, 508584, 486600], (2,2,2)), + "dw_stride" => reshape([ 74760, 72608, 64000, 61848, + 31720, 29568, 20960, 18808.], (2,2,2)), + "dw_pad" => reshape([1.26055e6, 1.30805e6, 1.40327e6, 1.44923e6, + 1.73731e6, 1.77589e6, 1.83259e6, 1.86731e6], (2,2,2)), + "dw_dil" => reshape([ 250320, 241512, 206280, 197472, + 74160, 65352, 30120, 21312.], (2,2,2)), + "dw_flip" => reshape([ 639480, 670200, 793080, 823800, + 1.25388e6, 1.2846e6, 1.40748e6, 1.4382e6], (2,2,2)), + ), +) + +@testset "Dense Convolution" begin + # Start with some easy-to-debug cases that we have worked through and _know_ work + for rank in (1,2,3) + @testset "conv$(rank)d" begin + # Pull out known-good answers for y = conv(x, w) + y_pad = conv_answer_dict[rank]["y_pad"] + y_dil = conv_answer_dict[rank]["y_dil"] + y_flip = conv_answer_dict[rank]["y_flip"] + + # We can always derive y_plain and y_stride from the other answers. + y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...] + y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...] + + # Same for dx and dw: + dx = conv_answer_dict[rank]["dx"] + dx_stride = conv_answer_dict[rank]["dx_stride"] + dx_pad = conv_answer_dict[rank]["dx_pad"] + dx_dil = conv_answer_dict[rank]["dx_dil"] + dx_flip = conv_answer_dict[rank]["dx_flip"] + + dw = conv_answer_dict[rank]["dw"] + dw_stride = conv_answer_dict[rank]["dw_stride"] + dw_pad = conv_answer_dict[rank]["dw_pad"] + dw_dil = conv_answer_dict[rank]["dw_dil"] + dw_flip = conv_answer_dict[rank]["dw_flip"] + + # We generate x and w from the shapes we know they must be + x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) + w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) + + # A "drop channels and batch dimension" helper + ddims(x) = dropdims(x, dims=(rank+1, rank+2)) + + for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct) + @testset "$(conv)" begin + # First, your basic convolution with no parameters + cdims = DenseConvDims(x, w) + @test ddims(conv(x, w, cdims)) == y_plain + + # Next, test convolution on views and alternate datatypes: + @test ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)) == y_plain + @test ddims(conv(Float32.(x), Float32.(w), cdims)) == Float32.(y_plain) + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + @test ddims(conv(x, w, cdims)) == y_stride + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + @test ddims(conv(x, w, cdims)) == y_dil + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + @test ddims(conv(x, w, cdims)) == y_pad + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + @test ddims(conv(x, w, cdims)) == y_flip + end + end + + # Test all implementations/interfaces + for (∇conv_filter, ∇conv_data) in ( + (NNlib.∇conv_filter, NNlib.∇conv_data), + (NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col), + (NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct), + ) + @testset "$(∇conv_filter)/$(∇conv_data)" begin + # First, 
your basic convolution with no parameters + cdims = DenseConvDims(x, w) + dy = NNlib.conv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw + @test ddims(∇conv_data(dy, w, cdims)) == dx + + # Next, test convolution on views and alternate datatypes: + @test ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)) == dw + @test ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w, cdims)) == dx + + @test ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)) == dw + @test ddims(∇conv_data(Float32.(dy), Float32.(w), cdims)) == dx + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + dy = NNlib.conv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_stride + @test ddims(∇conv_data(dy, w, cdims)) == dx_stride + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + dy = NNlib.conv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_dil + @test ddims(∇conv_data(dy, w, cdims)) == dx_dil + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + dy = NNlib.conv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_pad + @test ddims(∇conv_data(dy, w, cdims)) == dx_pad + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + dy = NNlib.conv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_flip + @test ddims(∇conv_data(dy, w, cdims)) == dx_flip + end + end + end end - z = ∇depthwiseconv_filter(dy, x, w) - for i in 1:2 - X = copy(x[:,:,i:i,:]); - W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3])) - DY = copy(dy[:,:,2i-1:2i,:]) - res = ∇conv_filter(DY,X; size=size(W)) - @test dropdims(z[:,:,:,i:i]; dims=(4)) == dropdims(res; dims=(3)) + @testset "fuzzing" begin + if get(ENV,"NNLIB_TEST_FUZZING","false") != "true" + @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them") + return + end + @info("Starting Convolutional fuzzing tests; this can take a few minutes...") + # Now that we're fairly certain things are working, let's fuzz things a little bit: + for x_size in ( + # 1d tests + (1,), (3,), (7,), + # 2d tests + (1, 3), (3, 3), (12, 3), (20, 17), + # 3d tests + (1, 1, 3), (3, 5, 4), (20, 17, 14), + ), + C_in in (1, 3), + batch in (1, 5) + + # Allocate x in this outer loop to save on allocations and speed things up + x = rand(x_size..., C_in, batch) + dx_direct = similar(x) + dx_im2col = similar(x) + + for w_size in ( + (1,), (3,), (7,), + (1,1), (1,3), (3,4), (7, 4), + (1,1,1), (1,1,3,), (3,4,3), (7,3,2)), + C_out in (1, 4) + + # Give some output to the user that something is in fact happening. + print(".") + + # Allocate w in this outer loop to save on allocations and speed things up + w = rand(w_size..., C_in, C_out) + dw_direct = similar(w) + dw_im2col = similar(w) + + for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)), + P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)), + D_size in (1, 2, 4, (1,2), (3,2), (4,2,3)) + + # Skip tests that are impossible due to mismatched sizes + try + DenseConvDims(x, w; + stride=S_size, padding=P_size, dilation=D_size, + ) + catch e + if isa(e, DimensionMismatch) || isa(e, MethodError) + continue + end + rethrow(e) + end + + # Do the actual convolution, comparing convolution implementations + cdims = DenseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) + + # We use mutating calls with explicitly different initial values, so as + # to be sure to catch when we're leaving pieces of the output untouched. 
+ y_direct = ones(output_size(cdims)..., C_out, batch) .* 666.666 + y_im2col = ones(output_size(cdims)..., C_out, batch) .* 777.777 + + # Do the convolutions + NNlib.conv_direct!(y_direct, x, w, cdims) + NNlib.conv_im2col!(y_im2col, x, w, cdims) + + # Compare! + @test y_direct ≈ y_im2col + dy = y_im2col + + # Now push backwards; first for the filter. Again, we initialize our + # memory so that segments that never get touched are immediately noticable + fill!(dw_direct, 666.666) + fill!(dw_im2col, 777.777) + NNlib.∇conv_filter_direct!(dw_direct, x, dy, cdims) + NNlib.∇conv_filter_im2col!(dw_im2col, x, dy, cdims) + @test dw_direct ≈ dw_im2col + + # And then for the input + fill!(dx_direct, 666.666) + fill!(dx_im2col, 777.777) + NNlib.∇conv_data_direct!(dx_direct, dy, w, cdims) + NNlib.∇conv_data_im2col!(dx_im2col, dy, w, cdims) + @test dx_direct ≈ dx_im2col + end + end + end + println() end - - @test size(∇depthwiseconv_filter(rand(2,2,4,1), x, w)) == size(w) - @test size(∇depthwiseconv_data(rand(2,2,4,1), x, w)) == size(x) - - # Test for the stride/pad for backward pass - y = depthwiseconv(x,w,stride=2,pad=1) - @test size(y) == (2,2,4,1) - @test size(∇depthwiseconv_filter(rand(Float64, size(y)), x, w, stride=2, pad=1)) == size(w) - @test size(∇depthwiseconv_data(rand(Float64, size(y)), x, w, stride=2, pad=1)) == size(x) -end - -@testset "maxpool2d" begin - - x = reshape(Float64[1:20;], 5, 4, 1, 1) - - @test dropdims(maxpool(x, (2,2)), dims = (3,4)) == [7 17; 9 19] - @test dropdims(maxpool(x, (2,2); stride=(2,2)), dims = (3,4)) == [7 17; 9 19] - @test dropdims(maxpool(x, (2,2); pad=(1,1)), dims = (3,4)) == [ - 1.0 11.0 16.0; - 3.0 13.0 18.0; - 5.0 15.0 20.0; - ] - - # for gradients, check only size - # correctness of gradients is cross-checked with CUDNN.jl - # (it's assumed maxpooling code won't change often) - - y = maxpool(x, (2,2)) - dy = reshape(rand(2,2), 2, 2, 1, 1) - @test size(∇maxpool(dy, y, x, (2,2))) == size(x) - -end - - -@testset "conv3d" begin - - x = reshape(Float64[1:60;], 5, 4, 3, 1, 1) - w = reshape(Float64[1:8;], 2, 2, 2, 1, 1) - res = zeros(4,3,2) - res[:, :, 1] = [ - 322.0 502.0 682.0; - 358.0 538.0 718.0; - 394.0 574.0 754.0; - 430.0 610.0 790.0] - res[:, :, 2] = [ - 1042.0 1222.0 1402.0; - 1078.0 1258.0 1438.0; - 1114.0 1294.0 1474.0; - 1150.0 1330.0 1510.0] - @test dropdims(conv(x, w), dims = (4,5)) == res - - @test dropdims(conv(Float32.(x), Float32.(w)), dims = (4,5)) == Float32.(res) - - @test dropdims(conv(x, w; stride=2), dims = (3,4,5)) == [ - 322.0 682.0; - 394.0 754.0] - - @test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4,5)) == Float32.([ - 322.0 682.0; - 394.0 754.0]) - - res = zeros(6,5,4) - res[:, :, 1] = [ - 1.0 9.0 29.0 49.0 48.0; - 4.0 29.0 79.0 129.0 115.0; - 7.0 39.0 89.0 139.0 122.0; - 10.0 49.0 99.0 149.0 129.0; - 13.0 59.0 109.0 159.0 136.0; - 10.0 40.0 70.0 100.0 80.0] - res[:, :, 2] = [ - 26.0 126.0 206.0 286.0 220.0; - 80.0 322.0 502.0 682.0 502.0; - 94.0 358.0 538.0 718.0 524.0; - 108.0 394.0 574.0 754.0 546.0; - 122.0 430.0 610.0 790.0 568.0; - 80.0 260.0 360.0 460.0 320.0] - res[:, :, 3] = [ - 146.0 446.0 526.0 606.0 420.0; - 360.0 1042.0 1222.0 1402.0 942.0; - 374.0 1078.0 1258.0 1438.0 964.0; - 388.0 1114.0 1294.0 1474.0 986.0; - 402.0 1150.0 1330.0 1510.0 1008.0; - 240.0 660.0 760.0 860.0 560.0] - res[:, :, 4] = [ - 205.0 517.0 577.0 637.0 392.0; - 456.0 1133.0 1263.0 1393.0 847.0; - 467.0 1159.0 1289.0 1419.0 862.0; - 478.0 1185.0 1315.0 1445.0 877.0; - 489.0 1211.0 1341.0 1471.0 892.0; - 270.0 660.0 730.0 800.0 480.0] 
- @test dropdims(conv(x, w; pad=1), dims = (4,5)) == res - - @test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (4,5)) == Float32.(res) - - @test dropdims(conv(x, w; dilation=2), dims = (3,4,5)) == [ - 608 788; - 644 824; - 680 860. - ] - - @test dropdims(conv(Float32.(x), Float32.(w); dilation=2), dims = (3,4,5)) == Float32.([ - 608 788; - 644 824; - 680 860. - ]) - # NaN tests for dilation forward pass - - ys = [] - for idx in 1:1000 - push!(ys, conv(x, w; dilation=2)) - end - @test !any([any(isnan.(ys[idx])) for idx in 1:1000]) - - # for gradients, check only size - # correctness of gradients is cross-checked with CUDNN.jl - # (it's assumed convolution code won't change often) - - @test size(∇conv_filter(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x; size=size(w))) == size(w) - @test size(∇conv_data(reshape(rand(4,3,2), 4, 3, 2, 1, 1), w; size=size(x))) == size(x) - - # NaN tests for dilation backward pass: filters - dy = randn(size(ys[1])) - dws = [] - for idx in 1:1000 - push!(dws, ∇conv_filter(dy, x; size=size(w), dilation=2)) - end - - # NaN tests for dilation backward pass: input - dxs = [] - for idx in 1:1000 - push!(dxs, ∇conv_data(dy, w; size=size(x), dilation=2)) - end - - @test !any([any(isnan.(dws[idx])) for idx in 1:1000]) - @test !any([any(isnan.(dxs[idx])) for idx in 1:1000]) - end -@testset "maxpool3d" begin - - x = reshape(Float64[1:60;], 5, 4, 3, 1, 1) - - @test dropdims(maxpool(x, (2,2,2)), dims = (3,4,5)) == [27 37; 29 39.] - @test dropdims(maxpool(x, (2,2,2); stride=(2,2,2)), dims = (3,4,5)) == [27 37; 29 39.] - res = zeros(3,3,2) - res[:, :, 1] = [ - 1.0 11.0 16.0; - 3.0 13.0 18.0; - 5.0 15.0 20.0] - res[:, :, 2] = [ - 41.0 51.0 56.0; - 43.0 53.0 58.0; - 45.0 55.0 60.0] - @test dropdims(maxpool(x, (2,2,2), pad=(1,1,1)), dims = (4,5)) == res - - # for gradients, check only size - # correctness of gradients is cross-checked with CUDNN.jl - # (it's assumed maxpooling code won't change often) - - y = maxpool(x, (2,2,2)) - dy = reshape(rand(2,2), 2, 2, 1, 1, 1) - @test size(∇maxpool(dy, y, x, (2,2,2))) == size(x) +@testset "Depthwise Convolution" begin + # Start with some easy-to-debug cases that we have worked through and _know_ work + for rank in (1,) #2,3) + @testset "depthwiseconv$(rank)d" begin + # Pull out known-good answers for y = depthwiseconv(x, w) + y_pad = conv_answer_dict[rank]["y_pad"] + y_dil = conv_answer_dict[rank]["y_dil"] + y_flip = conv_answer_dict[rank]["y_flip"] + + # We can always derive y_plain and y_stride from the other answers. + y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...] + y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...] 
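+            # (Since `y_pad` was computed with one-element padding, dropping its border
+            # recovers the unpadded answer, and keeping every other interior element
+            # recovers the stride-2 answer.)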
+ + # Same for dx and dw: + dx = conv_answer_dict[rank]["dx"] + dx_stride = conv_answer_dict[rank]["dx_stride"] + dx_pad = conv_answer_dict[rank]["dx_pad"] + dx_dil = conv_answer_dict[rank]["dx_dil"] + dx_flip = conv_answer_dict[rank]["dx_flip"] + + dw = conv_answer_dict[rank]["dw"] + dw_stride = conv_answer_dict[rank]["dw_stride"] + dw_pad = conv_answer_dict[rank]["dw_pad"] + dw_dil = conv_answer_dict[rank]["dw_dil"] + dw_flip = conv_answer_dict[rank]["dw_flip"] + + # We generate x and w from the shapes we know they must be + x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) + w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) + + # A "drop channels and batch dimension" helper + ddims(x) = dropdims(x, dims=(rank+1, rank+2)) + + for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct) + @testset "$(conv)" begin + # First, your basic convolution with no parameters + cdims = DepthwiseConvDims(x, w) + @test ddims(conv(x, w, cdims)) == y_plain + + # Next, test convolution on views and alternate datatypes: + @test ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)) == y_plain + @test ddims(conv(Float32.(x), Float32.(w), cdims)) == Float32.(y_plain) + + # Next, introduce stride: + cdims = DepthwiseConvDims(x, w; stride=2) + @test ddims(conv(x, w, cdims)) == y_stride + + # Next, introduce dilation: + cdims = DepthwiseConvDims(x, w; dilation=2) + @test ddims(conv(x, w, cdims)) == y_dil + + # Next, introduce padding: + cdims = DepthwiseConvDims(x, w; padding=1) + @test ddims(conv(x, w, cdims)) == y_pad + + # Next, test crosscor/conv with a flipped kernel + cdims = DepthwiseConvDims(x, w; flipkernel=true) + @test ddims(conv(x, w, cdims)) == y_flip + end + end + + # Test all implementations/interfaces + for (∇conv_filter, ∇conv_data) in ( + (NNlib.∇depthwiseconv_filter, NNlib.∇depthwiseconv_data), + (NNlib.∇depthwiseconv_filter_im2col, NNlib.∇depthwiseconv_data_im2col), + (NNlib.∇depthwiseconv_filter_direct, NNlib.∇depthwiseconv_data_direct), + ) + @testset "$(∇conv_filter)/$(∇conv_data)" begin + # First, your basic convolution with no parameters + cdims = DepthwiseConvDims(x, w) + dy = NNlib.depthwiseconv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw + @test ddims(∇conv_data(dy, w, cdims)) == dx + + # Next, test convolution on views and alternate datatypes: + @test ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)) == dw + @test ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w, cdims)) == dx + + @test ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)) == dw + @test ddims(∇conv_data(Float32.(dy), Float32.(w), cdims)) == dx + + # Next, introduce stride: + cdims = DepthwiseConvDims(x, w; stride=2) + dy = NNlib.depthwiseconv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_stride + @test ddims(∇conv_data(dy, w, cdims)) == dx_stride + + # Next, introduce dilation: + cdims = DepthwiseConvDims(x, w; dilation=2) + dy = NNlib.depthwiseconv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_dil + @test ddims(∇conv_data(dy, w, cdims)) == dx_dil + + # Next, introduce padding: + cdims = DepthwiseConvDims(x, w; padding=1) + dy = NNlib.depthwiseconv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_pad + @test ddims(∇conv_data(dy, w, cdims)) == dx_pad + + # Next, test crosscor/conv with a flipped kernel + cdims = DepthwiseConvDims(x, w; flipkernel=true) + dy = NNlib.depthwiseconv(x, w, cdims) + @test ddims(∇conv_filter(x, dy, cdims)) == dw_flip + @test ddims(∇conv_data(dy, w, cdims)) == 
dx_flip + end + end + end + end -end + @testset "fuzzing" begin + if get(ENV,"NNLIB_TEST_FUZZING","false") != "true" + @info("Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them") + return + end + @info("Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...") + # Now that we're fairly certain things are working, let's fuzz things a little bit: + for x_size in ( + # 1d tests + (1,), (3,), (7,), + # 2d tests + (1, 3), (3, 3), (12, 3), (20, 17), + # 3d tests + (1, 1, 3), (3, 5, 4), (20, 17, 14), + ), + C_in in (1, 3), + batch in (1, 5) + + # Allocate x in this outer loop to save on allocations and speed things up + x = rand(x_size..., C_in, batch) + dx_direct = similar(x) + dx_im2col = similar(x) + + for w_size in ( + (1,), (3,), (7,), + (1,1), (1,3), (3,4), (7, 4), + (1,1,1), (1,1,3,), (3,4,3), (7,3,2)), + C_mult in (1, 4) + + # Give some output to the user that something is in fact happening. + print(".") + + # Allocate w in this outer loop to save on allocations and speed things up + w = rand(w_size..., C_mult, C_in) + dw_direct = similar(w) + dw_im2col = similar(w) + + for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)), + P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)), + D_size in (1, 2, 4, (1,2), (3,2), (4,2,3)) + + # Skip tests that are impossible due to mismatched sizes + try + DepthwiseConvDims(x, w; + stride=S_size, padding=P_size, dilation=D_size, + ) + catch e + if isa(e, DimensionMismatch) || isa(e, MethodError) + continue + end + rethrow(e) + end + + # Do the actual convolution, comparing convolution implementations + cdims = DepthwiseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) + + # We use mutating calls with explicitly different initial values, so as + # to be sure to catch when we're leaving pieces of the output untouched. + y_direct = ones(output_size(cdims)..., channels_out(cdims), batch) .* 666.666 + y_im2col = ones(output_size(cdims)..., channels_out(cdims), batch) .* 777.777 + + # Do the convolutions + NNlib.depthwiseconv_direct!(y_direct, x, w, cdims) + NNlib.depthwiseconv_im2col!(y_im2col, x, w, cdims) + + # Compare! + @test y_direct ≈ y_im2col + dy = y_im2col + + # Now push backwards; first for the filter. 
Again, we initialize our + # memory so that segments that never get touched are immediately noticable + fill!(dw_direct, 666.666) + fill!(dw_im2col, 777.777) + NNlib.∇depthwiseconv_filter_direct!(dw_direct, x, dy, cdims) + NNlib.∇depthwiseconv_filter_im2col!(dw_im2col, x, dy, cdims) + @test dw_direct ≈ dw_im2col + + # And then for the input + fill!(dx_direct, 666.666) + fill!(dx_im2col, 777.777) + NNlib.∇depthwiseconv_data_direct!(dx_direct, dy, w, cdims) + NNlib.∇depthwiseconv_data_im2col!(dx_im2col, dy, w, cdims) + @test dx_direct ≈ dx_im2col + end + end + end + println() + end +end \ No newline at end of file diff --git a/test/perf/.gitignore b/test/perf/.gitignore new file mode 100644 index 000000000..d39008407 --- /dev/null +++ b/test/perf/.gitignore @@ -0,0 +1 @@ +*.jld2 diff --git a/test/perf/Manifest.toml b/test/perf/Manifest.toml new file mode 100644 index 000000000..af3330b39 --- /dev/null +++ b/test/perf/Manifest.toml @@ -0,0 +1,524 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "8d59c3b1463b5e0ad05a3698167f85fac90e184d" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.3.2" + +[[Arpack]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"] +git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.3.0" + +[[AxisAlgorithms]] +deps = ["Compat", "WoodburyMatrices"] +git-tree-sha1 = "99dabbe853e4f641ab21a676131f2cf9fb29937e" +uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" +version = "0.3.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BenchmarkTools]] +deps = ["JSON", "Printf", "Statistics", "Test"] +git-tree-sha1 = "5d1dd8577643ba9014574cd40d9c028cd5e4b85a" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "0.4.2" + +[[BinDeps]] +deps = ["Compat", "Libdl", "SHA", "URIParser"] +git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9" +uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +version = "0.8.10" + +[[BinaryProvider]] +deps = ["Libdl", "Pkg", "SHA", "Test"] +git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.3" + +[[Calculus]] +deps = ["Compat"] +git-tree-sha1 = "f60954495a7afcee4136f78d1d60350abd37a409" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.4.1" + +[[Clustering]] +deps = ["Dates", "Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase", "Test"] +git-tree-sha1 = "c39b2cbf3ee27716f725e358bcb6952f3ac177b3" +uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +version = "0.12.2" + +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] +git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.5.2" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random", "Test"] +git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.7.5" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] +git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.9.5" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[Compat]] +deps = ["Base64", 
"Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "1.5.1" + +[[Conda]] +deps = ["Compat", "JSON", "VersionParsing"] +git-tree-sha1 = "b625d802587c2150c279a40a646fba63f9bd8187" +uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" +version = "1.2.0" + +[[Contour]] +deps = ["LinearAlgebra", "StaticArrays", "Test"] +git-tree-sha1 = "b974e164358fea753ef853ce7bad97afec15bb80" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.5.1" + +[[Crayons]] +deps = ["Test"] +git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "1.0.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] +git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.15.0" + +[[DataValues]] +deps = ["Dates", "InteractiveUtils", "LinearAlgebra", "Random", "Test"] +git-tree-sha1 = "05e4a87fe52a2af1b4a1ffd3ab2fc996c038b192" +uuid = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" +version = "0.4.7" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffEqDiffTools]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "4b21dd83c341412a0607334ac64bb5593a4bd583" +uuid = "01453d9d-ee7c-5054-8395-0335cb756afa" +version = "0.8.0" + +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + +[[Distances]] +deps = ["LinearAlgebra", "Printf", "Random", "Statistics", "Test"] +git-tree-sha1 = "a135c7c062023051953141da8437ed74f89d767a" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.8.0" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[Distributions]] +deps = ["Distributed", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "c24e9b6500c037673f0241a2783472b8c3d080c7" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.16.4" + +[[FFTW]] +deps = ["AbstractFFTs", "BinaryProvider", "Compat", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] +git-tree-sha1 = "29cda58afbf62f35b1a094882ad6c745a47b2eaa" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "0.2.4" + +[[FileIO]] +deps = ["Pkg", "Random", "Test"] +git-tree-sha1 = "c94b0787956629036fb2b20fccde9e52b89d079a" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.0.5" + +[[FixedPointNumbers]] +deps = ["Test"] +git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.5.3" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = 
"f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + +[[GR]] +deps = ["Base64", "DelimitedFiles", "LinearAlgebra", "Pkg", "Printf", "Random", "Serialization", "Sockets", "Test"] +git-tree-sha1 = "41bd911efffb56957b45366770eaaa443de3f782" +uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" +version = "0.38.1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[Interpolations]] +deps = ["AxisAlgorithms", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "SharedArrays", "SparseArrays", "StaticArrays", "Test", "WoodburyMatrices"] +git-tree-sha1 = "e8d1c381b1dc5343e5b6d37265acbe1de493d512" +uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +version = "0.11.2" + +[[IterableTables]] +deps = ["DataValues", "IteratorInterfaceExtensions", "Requires", "TableTraits", "TableTraitsUtils", "Test"] +git-tree-sha1 = "0eec91e8185899f3926f56db515559bfe95b9db7" +uuid = "1c8ee90f-4401-5389-894e-7a04a3dc0f4d" +version = "0.10.0" + +[[IteratorInterfaceExtensions]] +deps = ["Test"] +git-tree-sha1 = "5484e5ede2a4137b9643f4d646e8e7b87b794415" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "0.1.1" + +[[JLD2]] +deps = ["CodecZlib", "DataStructures", "FileIO", "LinearAlgebra", "Mmap", "Printf", "Random", "Test"] +git-tree-sha1 = "3ba90ff93e1d5b9b2103588051c2d349fae54dac" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.1.2" + +[[JSON]] +deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"] +git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.20.0" + +[[KernelDensity]] +deps = ["Distributions", "FFTW", "Interpolations", "Optim", "StatsBase", "Test"] +git-tree-sha1 = "c1048817fe5711f699abc8fabd47b1ac6ba4db04" +uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" +version = "0.5.1" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf", "Test"] +git-tree-sha1 = "54eb90e8dbe745d617c78dee1d6ae95c7f6f5779" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.0.1" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Measures]] +deps = ["Test"] +git-tree-sha1 = "ddfd6d13e330beacdde2c80de27c1c671945e7d9" +uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +version = "0.3.0" + +[[Missings]] +deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] +git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NLSolversBase]] +deps = ["Calculus", "DiffEqDiffTools", "DiffResults", "Distributed", "ForwardDiff", "LinearAlgebra", "Random", "SparseArrays", "Test"] +git-tree-sha1 = "0c6f0e7f2178f78239cfb75310359eed10f2cacb" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.3.1" + +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "Test", "TimerOutputs"] +path = "../.." 
+uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.4.3+" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + +[[NearestNeighbors]] +deps = ["Distances", "LinearAlgebra", "Mmap", "StaticArrays", "Test"] +git-tree-sha1 = "f47c5d97cf9a8caefa47e9fa9d99d8fda1a65154" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.3" + +[[Observables]] +deps = ["Test"] +git-tree-sha1 = "dc02cec22747d1d10d9f70d8a1c03432b5bfbcd0" +uuid = "510215fc-4207-5dde-b226-833fc4488ee2" +version = "0.2.3" + +[[OffsetArrays]] +deps = ["DelimitedFiles", "Test"] +git-tree-sha1 = "e6893807f09c1d5517861ded8b203cb96cb7d44a" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "0.10.0" + +[[Optim]] +deps = ["Calculus", "DiffEqDiffTools", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "Random", "SparseArrays", "StatsBase", "Test"] +git-tree-sha1 = "0f2a6c6ff9db396cc7af15bb1cf057a26662ff17" +uuid = "429524aa-4258-5aef-a3af-852621145aeb" +version = "0.17.2" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.0.2" + +[[PDMats]] +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.9.6" + +[[Parameters]] +deps = ["Markdown", "OrderedCollections", "REPL", "Test"] +git-tree-sha1 = "70bdbfb2bceabb15345c0b54be4544813b3444e4" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.10.3" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PlotThemes]] +deps = ["PlotUtils", "Requires", "Test"] +git-tree-sha1 = "f3afd2d58e1f6ac9be2cea46e4a9083ccc1d990b" +uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" +version = "0.3.0" + +[[PlotUtils]] +deps = ["Colors", "Dates", "Printf", "Random", "Reexport", "Test"] +git-tree-sha1 = "fd28f30a294a38ec847de95d8ac7ac916ccd7c06" +uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" +version = "0.5.5" + +[[Plots]] +deps = ["Base64", "Contour", "Dates", "FixedPointNumbers", "GR", "JSON", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "Reexport", "Requires", "Showoff", "SparseArrays", "StaticArrays", "Statistics", "StatsBase", "Test", "UUIDs"] +git-tree-sha1 = "c68a9ec8a13a5bdcb85c311378a86b7d7b9b0792" +uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +version = "0.23.1" + +[[PositiveFactorizations]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "86ae7329c4b5c266acf5c7c524a972300d991e1c" +uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" +version = "0.2.1" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra", "Test"] +git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.0.3" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Ratios]] +deps = ["Compat"] +git-tree-sha1 = "fd159bead0a24e6270fd0573a340312bd4645cc2" +uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" +version = 
"0.3.0" + +[[RecipesBase]] +deps = ["Random", "Test"] +git-tree-sha1 = "0b3cb370ee4dc00f47f1193101600949f3dcf884" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "0.6.0" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["Test"] +git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "0.5.2" + +[[Rmath]] +deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"] +git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.5.0" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Showoff]] +deps = ["Compat"] +git-tree-sha1 = "276b24f3ace98bec911be7ff2928d497dc759085" +uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" +version = "0.2.1" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] +git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.7.2" + +[[StaticArrays]] +deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.10.3" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] +git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.29.0" + +[[StatsFuns]] +deps = ["Rmath", "SpecialFunctions", "Test"] +git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.8.0" + +[[StatsPlots]] +deps = ["Clustering", "DataStructures", "DataValues", "Distributions", "IterableTables", "KernelDensity", "Observables", "Plots", "RecipesBase", "Reexport", "StatsBase", "TableTraits", "TableTraitsUtils", "Test", "Widgets"] +git-tree-sha1 = "d722a2d4293ded61124654aae6696c68d7946a95" +uuid = "f3b207a7-027a-5e70-b257-86293d7955fd" +version = "0.10.2" + +[[SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[TableTraits]] +deps = ["IteratorInterfaceExtensions", "Test"] +git-tree-sha1 = "eba4b1d0a82bdd773307d652c6e5f8c82104c676" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "0.4.1" + +[[TableTraitsUtils]] +deps = ["DataValues", "IteratorInterfaceExtensions", "Missings", "TableTraits", "Test"] +git-tree-sha1 = "55133a5476b61ec31060e555ffe12da27ac13682" +uuid = "382cd787-c1b6-5bf2-a167-d5b971a19bda" +version = "0.4.0" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + 
+[[TimerOutputs]] +deps = ["Crayons", "Printf", "Test", "Unicode"] +git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.0" + +[[TranscodingStreams]] +deps = ["Pkg", "Random", "Test"] +git-tree-sha1 = "8a032ceb5cf7a28bf1bdb77746b250b9e9fda565" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.0" + +[[URIParser]] +deps = ["Test", "Unicode"] +git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.0" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[VersionParsing]] +deps = ["Compat"] +git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669" +uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" +version = "1.1.3" + +[[Widgets]] +deps = ["Colors", "Dates", "Observables", "OrderedCollections", "Test"] +git-tree-sha1 = "f48ee34d9495924aba50eeb328d83b0034b787f5" +uuid = "cc8bc4a8-27d6-5769-a93b-9d913e69aa62" +version = "0.5.0" + +[[WoodburyMatrices]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"] +git-tree-sha1 = "21772c33b447757ec7d3e61fcdfb9ea5c47eedcf" +uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" +version = "0.4.1" diff --git a/test/perf/Project.toml b/test/perf/Project.toml new file mode 100644 index 000000000..3a54aa5de --- /dev/null +++ b/test/perf/Project.toml @@ -0,0 +1,7 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" diff --git a/test/perf/perf_report.jl b/test/perf/perf_report.jl new file mode 100644 index 000000000..b9c193753 --- /dev/null +++ b/test/perf/perf_report.jl @@ -0,0 +1,99 @@ +using JLD2, NNlib, BenchmarkTools + +# We need things to go quickly here +BenchmarkTools.DEFAULT_PARAMETERS.samples = 20 +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5 + +results = Dict() + +function add_result(val, keys...) 
+ r = results + for k in keys[1:end-1] + if !haskey(r, k) + r[k] = Dict() + end + r = r[k] + end + r[keys[end]] = val + return r +end + +# Modify these as needed +for rank in (2,), + N in (20, 40, 80), + C_in in (1,), + C_out in (1,), + K in (3,), + stride in (1,), + dilation in (1,), + padding in (0, 2) + + for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in ( + (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"), + (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"), + (NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"), + (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"), + ) + + x = zeros(Float32, repeat([N], rank)..., C_in, 1) + if cT == DenseConvDims + w = zeros(Float32, repeat([K], rank)..., C_in, C_out) + else + w = zeros(Float32, repeat([K], rank)..., C_out, C_in) + end + cdims = try + cT(x, w; stride=stride, dilation=dilation, padding=padding) + catch + continue + end + + if cT == DenseConvDims + y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1) + else + y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1) + end + + dx = similar(x) + dw = similar(w) + dy = similar(y) + + t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims) + t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims) + t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims) + + add_result(t_fwd, "conv$(rank)d", backend, cdims) + add_result(t_dx, "conv$(rank)d_data", backend, cdims) + add_result(t_dw, "conv$(rank)d_filter", backend, cdims) + + @show(cdims) + @save "results.jld2" results + end +end + + +# Modify these as needed +for rank in (2,), + N in (20,), + K in (2, 4), + stride in (1, 2, 4) + + x = zeros(Float32, repeat([N], rank)..., 1, 1) + pdims = PoolDims(x, K; stride=stride) + y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1) + dx = similar(x) + + for (pool, ∇pool, name) in ( + (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"), + (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"), + ) + + t_fwd = @benchmark $(pool)( $y, $x, $pdims) + t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims) + + add_result(t_fwd, "$(name)$(rank)d", "direct", pdims) + add_result(t_data, "$(name)$(rank)d_data", "direct", pdims) + + @show(pdims) + @save "results.jld2" results + end +end diff --git a/test/pooling.jl b/test/pooling.jl new file mode 100644 index 000000000..7b9808684 --- /dev/null +++ b/test/pooling.jl @@ -0,0 +1,299 @@ +using NNlib, Test + +maxpool_answer_dict = Dict( + 1 => Dict( + "y" => [2, 4.], + "y_nostride" => [2, 3, 4, 5.], + "y_pad" => [1, 3, 5.], + + "dx" => [0, 2, 0, 4, 0.], + "dx_nostride" => [0, 2, 3, 4, 5.], + "dx_pad" => [1, 0, 3, 0, 5.], + ), + 2 => Dict( + "y" => [ + 7 17.; + 9 19. + ], + "y_nostride" => [ + 7 12 17; + 8 13 18; + 9 14 19; + 10 15 20. + ], + "y_pad" => [ + 1 11 16; + 3 13 18; + 5 15 20. + ], + + "dx" => [ + 0 0 0 0; + 0 7 0 17; + 0 0 0 0; + 0 9 0 19; + 0 0 0 0. + ], + "dx_nostride" => [ + 0 0 0 0; + 0 7 12 17; + 0 8 13 18; + 0 9 14 19; + 0 10 15 20. + ], + "dx_pad" => [ + 1 0 11 16; + 0 0 0 0; + 3 0 13 18; + 0 0 0 0; + 5 0 15 20. + ], + ), + 3 => Dict( + "y" => reshape([ + 27, 29, + 37, 39. + ], (2, 2, 1)), + "y_nostride" => reshape([ + 27, 28, 29, 30, + 32, 33, 34, 35, + 37, 38, 39, 40, + + 47, 48, 49, 50, + 52, 53, 54, 55, + 57, 58, 59, 60. 
+ ], (4, 3, 2)), + "y_pad" => reshape([ + 1, 3, 5, + 11, 13, 15, + 16, 18, 20, + + 41, 43, 45, + 51, 53, 55, + 56, 58, 60. + ], (3, 3, 2)), + + "dx" => reshape([ + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, + 0, 27, 0, 29, 0, + 0, 0, 0, 0, 0, + 0, 37, 0, 39, 0, + + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0. + ], (5, 4, 3)), + "dx_nostride" => reshape([ + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, + 0, 27, 28, 29, 30, + 0, 32, 33, 34, 35, + 0, 37, 38, 39, 40, + + 0, 0, 0, 0, 0, + 0, 47, 48, 49, 50, + 0, 52, 53, 54, 55, + 0, 57, 58, 59, 60. + ], (5, 4, 3)), + "dx_pad" => reshape([ + 1, 0, 3, 0, 5, + 0, 0, 0, 0, 0, + 11, 0, 13, 0, 15, + 16, 0, 18, 0, 20, + + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + + 41, 0, 43, 0, 45, + 0, 0, 0, 0, 0, + 51, 0, 53, 0, 55, + 56, 0, 58, 0, 60. + ], (5, 4, 3)), + ) +) + +meanpool_answer_dict = Dict( + 1 => Dict( + "y" => [1.5, 3.5], + "y_nostride" => [1.5, 2.5, 3.5, 4.5], + "y_pad" => [0.5, 2.5, 4.5], + + "dx" => [0.75, 0.75, 1.75, 1.75, 0.0], + "dx_nostride" => [0.75, 2.0, 3.0, 4.0, 2.25], + "dx_pad" => [0.25, 1.25, 1.25, 2.25, 2.25], + ), + 2 => Dict( + "y" => [ + 4.0 14.0; + 6.0 16.0 + ], + "y_nostride" => [ + 4.0 9.0 14.0 + 5.0 10.0 15.0 + 6.0 11.0 16.0 + 7.0 12.0 17.0 + ], + "y_pad" => [ + 0.25 4.25 4.0 + 1.25 10.0 8.75 + 2.25 12.0 9.75 + ], + + "dx" => [ + 1.0 1.0 3.5 3.5; + 1.0 1.0 3.5 3.5; + 1.5 1.5 4.0 4.0; + 1.5 1.5 4.0 4.0; + 0.0 0.0 0.0 0.0 + ], + "dx_nostride" => [ + 1.0 3.25 5.75 3.5; + 2.25 7.0 12.0 7.25; + 2.75 8.0 13.0 7.75; + 3.25 9.0 14.0 8.25; + 1.75 4.75 7.25 4.25 + ], + "dx_pad" => [ + 0.0625 1.0625 1.0625 1.0; + 0.3125 2.5 2.5 2.1875; + 0.3125 2.5 2.5 2.1875; + 0.5625 3.0 3.0 2.4375; + 0.5625 3.0 3.0 2.4375 + ], + ), + 3 => Dict( + "y" => reshape([ + 14.0, 16.0, + 24.0, 26.0 + ], (2, 2, 1)), + "y_nostride" => reshape([ + 14.0, 15.0, 16.0, 17.0, + 19.0, 20.0, 21.0, 22.0, + 24.0, 25.0, 26.0, 27.0, + + 34.0, 35.0, 36.0, 37.0, + 39.0, 40.0, 41.0, 42.0, + 44.0, 45.0, 46.0, 47.0 + ], (4, 3, 2)), + "y_pad" => reshape([ + 0.125, 0.625, 1.125, + 2.125, 5.0, 6.0, + 2.0, 4.375, 4.875, + + 7.75, 16.25, 17.25, + 19.25, 40.0, 42.0, + 11.5, 23.75, 24.75, + ], (3, 3, 2)), + + "dx" => reshape([ + 1.75, 1.75, 2.0, 2.0, 0.0, + 1.75, 1.75, 2.0, 2.0, 0.0, + 3.0, 3.0, 3.25, 3.25, 0.0, + 3.0, 3.0, 3.25, 3.25, 0.0, + + 1.75, 1.75, 2.0, 2.0, 0.0, + 1.75, 1.75, 2.0, 2.0, 0.0, + 3.0, 3.0, 3.25, 3.25, 0.0, + 3.0, 3.0, 3.25, 3.25, 0.0, + + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + ], (5, 4, 3)), + "dx_nostride" => reshape([ + 1.75, 3.625, 3.875, 4.125, 2.125, + 4.125, 8.5, 9.0, 9.5, 4.875, + 5.375, 11.0, 11.5, 12.0, 6.125, + 3.0, 6.125, 6.375, 6.625, 3.375, + + 6.0, 12.25, 12.75, 13.25, 6.75, + 13.25, 27.0, 28.0, 29.0, 14.75, + 15.75, 32.0, 33.0, 34.0, 17.25, + 8.5, 17.25, 17.75, 18.25, 9.25, + + 4.25, 8.625, 8.875, 9.125, 4.625, + 9.125, 18.5, 19.0, 19.5, 9.875, + 10.375, 21.0, 21.5, 22.0, 11.125, + 5.5, 11.125, 11.375, 11.625, 5.875 + ], (5, 4, 3)), + "dx_pad" => reshape([ + 0.015625, 0.078125, 0.078125, 0.140625, 0.140625, + 0.265625, 0.625, 0.625, 0.75, 0.75, + 0.265625, 0.625, 0.625, 0.75, 0.75, + 0.25, 0.546875, 0.546875, 0.609375, 0.609375, + + 0.96875, 2.03125, 2.03125, 2.15625, 2.15625, + 2.40625, 5.0, 5.0, 5.25, 5.25, + 2.40625, 5.0, 5.0, 5.25, 5.25, + 1.4375, 2.96875, 2.96875, 3.09375, 3.09375, + + 0.96875, 2.03125, 2.03125, 2.15625, 2.15625, + 2.40625, 5.0, 
5.0, 5.25, 5.25, + 2.40625, 5.0, 5.0, 5.25, 5.25, + 1.4375, 2.96875, 2.96875, 3.09375, 3.09375 + ], (5, 4, 3)), + ) +) + +for rank in (1, 2, 3) + @testset "pool$(rank)d" begin + for (pool, ∇pool, answer_dict) in ( + # Main API name + (maxpool, ∇maxpool, maxpool_answer_dict), + (meanpool, ∇meanpool, meanpool_answer_dict), + + # _direct name + (NNlib.maxpool_direct, NNlib.∇maxpool_direct, maxpool_answer_dict), + (NNlib.meanpool_direct, NNlib.∇meanpool_direct, meanpool_answer_dict), + ) + + @testset "$(pool)$(rank)d" begin + y = answer_dict[rank]["y"] + y_nostride = answer_dict[rank]["y_nostride"] + y_pad = answer_dict[rank]["y_pad"] + dx = answer_dict[rank]["dx"] + dx_nostride = answer_dict[rank]["dx_nostride"] + dx_pad = answer_dict[rank]["dx_pad"] + + x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) + + # A "drop channels and batch dimension" helper + ddims(x) = dropdims(x, dims=(rank+1, rank+2)) + + # Let's ensure that a 1x1x1 pooling kernel always just returns `x` + @test pool(x, PoolDims(x, 1)) == x + + # Test vanilla pooling + pdims = PoolDims(x, 2) + y_hat = pool(x, pdims) + @test ddims(y_hat) == y + @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx + + # Strided pooling + pdims = PoolDims(x, 2; stride=1) + y_hat = pool(x, pdims) + @test ddims(y_hat) == y_nostride + @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_nostride + + # Padded pooling + pdims = PoolDims(x, 2; padding=1) + y_hat = pool(x, pdims) + @test ddims(y_hat) == y_pad + @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_pad + end + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 901c1e2c1..0e20baea7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,39 +1,5 @@ using NNlib, Test -@testset "NNlib" begin - include("activation.jl") include("conv.jl") - -xs = [-100_000, -100_000.] -@test softmax(xs) ≈ [0.5, 0.5] -@test logsoftmax(xs) ≈ log.([0.5, 0.5]) - -xs = rand(5) -@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs)) -@test logsoftmax(xs) ≈ log.(softmax(xs)) -@test logsigmoid.(xs) ≈ log.(sigmoid.(xs)) - -xs = rand(5,10) -@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs), dims = 1) -@test logsoftmax(xs) ≈ log.(softmax(xs)) -@test logsigmoid.(xs) ≈ log.(sigmoid.(xs)) - -for T in [:Float32, :Float64] - @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.] -end - -## compare the outputs with the PyTorch nn.LogSoftmax returns -xs = Float32[1, 2, 3000.] -@test logsoftmax(xs) ≈ [-2999, -2998, 0] - -xs = Float32[1 2 3; 1000 2000 3000] -@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.] - -@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) -@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) - -xs = [-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842; 0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663; -1.14637 -0.577988 0.718952 0.91972 -0.620773 0.929977] -@test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), [0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273; -0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428; 0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697]; rtol = 1e-6) -@test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6) -end +include("pooling.jl") \ No newline at end of file
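For orientation, the snippet below is a minimal usage sketch of the pooling API exercised by test/pooling.jl above; it is not part of the patch. It assumes only what the new tests show: `PoolDims(x, k; stride, padding)` describes the pooling window, `maxpool(x, pdims)` applies it, and `∇maxpool(dy, y, x, pdims)` returns the gradient with respect to `x`. The example values mirror the rank-1 case in the test data.

using NNlib

# 1-D input laid out as (spatial, channels, batch), as the tests do
x = reshape(Float64[1, 2, 3, 4, 5], 5, 1, 1)

# Size-2 pooling window; the stride defaults to the window size
pdims = PoolDims(x, 2)

y  = maxpool(x, pdims)          # windows [1,2] and [3,4] -> values [2, 4]
dx = ∇maxpool(y, y, x, pdims)   # gradient w.r.t. x, here with dy = y, as in the tests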