diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index acee725c4..70827d0aa 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,11 +17,11 @@ Creates a ConvMixer model. function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000) stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation; - preact = true, groups = planes, pad = SamePad())...), +), - conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] + blocks = [Chain(SkipConnection(conv_bn(kernel_size, planes, planes, activation; + preact = true, groups = planes, pad = SamePad()), +), + conv_bn((1, 1), planes, planes, activation; preact = true)) for _ in 1:depth] head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses)) - return Chain(Chain(stem..., blocks...), head) + return Chain(Chain(stem, Chain(blocks)), head) end convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9), diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 9d866f3d1..1621803bf 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -10,7 +10,7 @@ Creates a single block of ConvNeXt. - `λ`: Init value for LayerScale """ function convnextblock(planes, drop_path_rate = 0., λ = 1f-6) - layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3), + layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3), swapdims((3, 1, 2, 4)), LayerNorm(planes; ϵ = 1f-6), mlp_block(planes, 4 * planes), @@ -61,7 +61,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6 LayerNorm(planes[end]), Dense(planes[end], nclasses)) - return Chain(Chain(backbone...), head) + return Chain(Chain(backbone), head) end # Configurations for ConvNeXt models diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 0fc3980b5..eff19f1a8 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -11,10 +11,10 @@ Create a Densenet bottleneck layer """ function dense_bottleneck(inplanes, outplanes) inner_channels = 4 * outplanes - m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true)..., - conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true)...) + m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true), + conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true)) - SkipConnection(m, (mx, x) -> cat(x, mx; dims = 3)) + SkipConnection(m, cat_channels) end """ @@ -28,8 +28,7 @@ Create a DenseNet transition sequence - `outplanes`: number of output feature maps """ transition(inplanes, outplanes) = - [conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., - MeanPool((2, 2))] + Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true), MeanPool((2, 2))) """ dense_block(inplanes, growth_rates) @@ -60,20 +59,21 @@ Create a DenseNet model - `nclasses`: the number of output classes """ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000) - layers = conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false) + layers = [] + push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)) push!(layers, MaxPool((3, 3), stride = 2, pad = (1, 1))) outplanes = 0 for (i, rates) in enumerate(growth_rates) outplanes = inplanes + sum(rates) append!(layers, dense_block(inplanes, rates)) - (i != length(growth_rates)) && - append!(layers, transition(outplanes, floor(Int, outplanes * reduction))) + (i != length(growth_rates)) && + push!(layers, transition(outplanes, floor(Int, outplanes * reduction))) inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), + return Chain(Chain(layers), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses))) diff --git a/src/convnets/googlenet.jl b/src/convnets/googlenet.jl index 4de47a0ef..bc42a052f 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/googlenet.jl @@ -15,16 +15,16 @@ Create an inception module for use in GoogLeNet """ function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) branch1 = Chain(Conv((1, 1), inplanes => out_1x1)) - + branch2 = Chain(Conv((1, 1), inplanes => red_3x3), Conv((3, 3), red_3x3 => out_3x3; pad = 1)) - + branch3 = Chain(Conv((1, 1), inplanes => red_5x5), Conv((5, 5), red_5x5 => out_5x5; pad = 2)) - + branch4 = Chain(MaxPool((3, 3), stride=1, pad = 1), Conv((1, 1), inplanes => pool_proj)) - + return Parallel(cat_channels, branch1, branch2, branch3, branch4) end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index a9a33ed50..00bdd0ccb 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -9,17 +9,17 @@ Create an Inception-v3 style-A module - `pool_proj`: the number of output feature maps for the pooling projection """ function inception_a(inplanes, pool_proj) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)...) - - branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)..., - conv_bn((5, 5), 48, 64; pad = 2)...) + branch1x1 = conv_bn((1, 1), inplanes, 64) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; pad = 1)...) + branch5x5 = Chain(conv_bn((1, 1), inplanes, 48), + conv_bn((5, 5), 48, 64; pad = 2)) + + branch3x3 = Chain(conv_bn((1, 1), inplanes, 64), + conv_bn((3, 3), 64, 96; pad = 1), + conv_bn((3, 3), 96, 96; pad = 1)) branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1), - conv_bn((1, 1), inplanes, pool_proj)...) + conv_bn((1, 1), inplanes, pool_proj)) return Parallel(cat_channels, branch1x1, branch5x5, branch3x3, branch_pool) @@ -35,13 +35,13 @@ Create an Inception-v3 style-B module - `inplanes`: number of input feature maps """ function inception_b(inplanes) - branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)...) + branch3x3_1 = conv_bn((3, 3), inplanes, 384; stride = 2) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)..., - conv_bn((3, 3), 64, 96; pad = 1)..., - conv_bn((3, 3), 96, 96; stride = 2)...) + branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64), + conv_bn((3, 3), 64, 96; pad = 1), + conv_bn((3, 3), 96, 96; stride = 2)) - branch_pool = Chain(MaxPool((3, 3), stride = 2)) + branch_pool = MaxPool((3, 3), stride = 2) return Parallel(cat_channels, branch3x3_1, branch3x3_2, branch_pool) @@ -59,20 +59,20 @@ Create an Inception-v3 style-C module - `n`: the "grid size" (kernel size) for the convolution layers """ function inception_c(inplanes, inner_planes, n = 7) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)...) + branch1x1 = conv_bn((1, 1), inplanes, 192) - branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...) + branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes), + conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)), + conv_bn((n, 1), inner_planes, 192; pad = (3, 0))) - branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))..., - conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))..., - conv_bn((1, n), inner_planes, 192; pad = (0, 3))...) + branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes), + conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)), + conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)), + conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)), + conv_bn((1, n), inner_planes, 192; pad = (0, 3))) - branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1), - conv_bn((1, 1), inplanes, 192)...) + branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1), + conv_bn((1, 1), inplanes, 192)) return Parallel(cat_channels, branch1x1, branch7x7_1, branch7x7_2, branch_pool) @@ -88,15 +88,15 @@ Create an Inception-v3 style-D module - `inplanes`: number of input feature maps """ function inception_d(inplanes) - branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((3, 3), 192, 320; stride = 2)...) + branch3x3 = Chain(conv_bn((1, 1), inplanes, 192), + conv_bn((3, 3), 192, 320; stride = 2)) - branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)..., - conv_bn((1, 7), 192, 192; pad = (0, 3))..., - conv_bn((7, 1), 192, 192; pad = (3, 0))..., - conv_bn((3, 3), 192, 192; stride = 2)...) + branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192), + conv_bn((1, 7), 192, 192; pad = (0, 3)), + conv_bn((7, 1), 192, 192; pad = (3, 0)), + conv_bn((3, 3), 192, 192; stride = 2)) - branch_pool = Chain(MaxPool((3, 3), stride=2)) + branch_pool = MaxPool((3, 3), stride=2) return Parallel(cat_channels, branch3x3, branch7x7x3, branch_pool) @@ -112,26 +112,26 @@ Create an Inception-v3 style-E module - `inplanes`: number of input feature maps """ function inception_e(inplanes) - branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)...) + branch1x1 = conv_bn((1, 1), inplanes, 320) - branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)...) - branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...) - branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_1 = conv_bn((1, 1), inplanes, 384) + branch3x3_1a = conv_bn((1, 3), 384, 384; pad = (0, 1)) + branch3x3_1b = conv_bn((3, 1), 384, 384; pad = (1, 0)) - branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)..., - conv_bn((3, 3), 448, 384; pad = 1)...) - branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...) - branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...) + branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448), + conv_bn((3, 3), 448, 384; pad = 1)) + branch3x3_2a = conv_bn((1, 3), 384, 384; pad = (0, 1)) + branch3x3_2b = conv_bn((3, 1), 384, 384; pad = (1, 0)) branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1), - conv_bn((1, 1), inplanes, 192)...) + conv_bn((1, 1), inplanes, 192)) return Parallel(cat_channels, branch1x1, Chain(branch3x3_1, Parallel(cat_channels, branch3x3_1a, branch3x3_1b)), - + Chain(branch3x3_2, Parallel(cat_channels, branch3x3_2a, branch3x3_2b)), @@ -150,12 +150,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). `inception3` does not currently support pretrained weights. """ function inception3(; nclasses = 1000) - layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)..., - conv_bn((3, 3), 32, 32)..., - conv_bn((3, 3), 32, 64; pad = 1)..., + layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2), + conv_bn((3, 3), 32, 32), + conv_bn((3, 3), 32, 64; pad = 1), MaxPool((3, 3), stride = 2), - conv_bn((1, 1), 64, 80)..., - conv_bn((3, 3), 80, 192)..., + conv_bn((1, 1), 64, 80), + conv_bn((3, 3), 80, 192), MaxPool((3, 3), stride = 2), inception_a(192, 32), inception_a(256, 64), diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 69d448e68..186726ef9 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -31,17 +31,15 @@ function mobilenetv1(width_mult, config; for (dw, outch, stride, repeats) in config outch = Int(outch * width_mult) for _ in 1:repeats - layer = if dw - depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1) - else - conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1) - end - append!(layers, layer) + layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; + stride = stride, pad = 1) : + conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1) + push!(layers, layer) inchannels = outch end end - return Chain(Chain(layers...), + return Chain(Chain(layers), Chain(GlobalMeanPool(), MLUtils.flatten, Dense(inchannels, fcsize, activation), @@ -120,7 +118,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000) # building first layer inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8) layers = [] - append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2)) + push!(layers, conv_bn((3, 3), 3, inplanes, stride = 2)) # building inverted residual blocks for (t, c, n, s, a) in configs @@ -136,8 +134,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000) outplanes = (width_mult > 1) ? _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) : max_width - return Chain(Chain(layers..., - conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...), + return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses))) end @@ -186,7 +183,7 @@ end (m::MobileNetv2)(x) = m.layers(x) backbone(m::MobileNetv2) = m.layers[1] -classifier(m::MobileNetv2) = m.layers[2:end] +classifier(m::MobileNetv2) = m.layers[2] # MobileNetv3 @@ -214,7 +211,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) layers = [] - append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2)) + push!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2)) explanes = 0 # building inverted residual blocks for (k, t, c, r, a, s) in configs @@ -229,13 +226,12 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000) # building last several layers output_channel = max_width output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel - classifier = (Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses)) + classifier = Chain(Dense(explanes, output_channel, hardswish), + Dropout(0.2), + Dense(output_channel, nclasses)) - return Chain(Chain(layers..., - conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...), - Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier...)) + return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)), + Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) end # Configurations for small and large mode for MobileNetv3 @@ -310,4 +306,4 @@ end (m::MobileNetv3)(x) = m.layers(x) backbone(m::MobileNetv3) = m.layers[1] -classifier(m::MobileNetv3) = m.layers[2:end] +classifier(m::MobileNetv3) = m.layers[2] diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 97c689432..5de0e35dd 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -12,8 +12,8 @@ Create a basic residual block """ function basicblock(inplanes, outplanes, downsample = false) stride = downsample ? 2 : 1 - Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false)...) + Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false), + conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false)) end """ @@ -36,9 +36,9 @@ The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead. """ function bottleneck(inplanes, outplanes, downsample = false; stride = [1, (downsample ? 2 : 1), 1]) - Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false)..., - conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false)..., - conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false)...) + Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false), + conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false), + conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false)) end @@ -82,7 +82,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection = inplanes = 64 baseplanes = 64 layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false)) + push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false)) push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1))) for (i, nrepeats) in enumerate(block_config) # output planes within a block @@ -102,7 +102,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection = baseplanes *= 2 end - return Chain(Chain(layers...), + return Chain(Chain(layers), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses))) end @@ -246,7 +246,7 @@ function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000) model end -# Compat with Methalhead 0.6; remove in 0.7 +# Compat with Metalhead 0.6; remove in 0.7 @deprecate ResNet18(; kw...) ResNet(18; kw...) @deprecate ResNet34(; kw...) ResNet(34; kw...) @deprecate ResNet50(; kw...) ResNet(50; kw...) diff --git a/src/convnets/resnext.jl b/src/convnets/resnext.jl index e3ebd21b3..c9d7aa669 100644 --- a/src/convnets/resnext.jl +++ b/src/convnets/resnext.jl @@ -13,19 +13,17 @@ Create a basic residual block as defined in the paper for ResNeXt """ function resnextblock(inplanes, outplanes, cardinality, width, downsample = false) stride = downsample ? 2 : 1 - hidden_channels = cardinality * width - - return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false)..., + return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false), conv_bn((3, 3), hidden_channels, hidden_channels; - stride = stride, pad = 1, bias = false, groups = cardinality)..., - conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)...) + stride = stride, pad = 1, bias = false, groups = cardinality), + conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)) end """ resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. relu(x) + relu(y); block_config, nclasses = 1000) - + Create a ResNeXt model ([reference](https://arxiv.org/abs/1611.05431)). @@ -42,7 +40,7 @@ function resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. inplanes = 64 baseplanes = 128 layers = [] - append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3))) + push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3))) push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1))) for (i, nrepeats) in enumerate(block_config) # output planes within a block @@ -62,13 +60,13 @@ function resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @. width *= widen_factor end - return Chain(Chain(layers...), + return Chain(Chain(layers), Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses))) end """ ResNeXt(cardinality, width; block_config, nclasses = 1000) - + Create a ResNeXt model ([reference](https://arxiv.org/abs/1611.05431)). diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index a0e63d689..6cc9dab83 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -16,7 +16,7 @@ function vgg_block(ifilters, ofilters, depth, batchnorm) layers = [] for _ in 1:depth if batchnorm - append!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false)) + push!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false)) else push!(layers, Conv(k, ifilters => ofilters, relu, pad = p)) end @@ -62,15 +62,12 @@ Create VGG classifier (fully connected) layers - `dropout`: the dropout level between each fully connected layer """ function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) - layers = [] - push!(layers, MLUtils.flatten) - push!(layers, Dense(Int(prod(imsize)), fcsize, relu)) - push!(layers, Dropout(dropout)) - push!(layers, Dense(fcsize, fcsize, relu)) - push!(layers, Dropout(dropout)) - push!(layers, Dense(fcsize, nclasses)) - - return layers + return Chain(MLUtils.flatten, + Dense(Int(prod(imsize)), fcsize, relu), + Dropout(dropout), + Dense(fcsize, fcsize, relu), + Dropout(dropout), + Dense(fcsize, nclasses)) end """ @@ -94,7 +91,7 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout) - return Chain(Chain(conv...), Chain(class...)) + return Chain(Chain(conv), class) end const vgg_conv_config = Dict(:A => [(64,1), (128,1), (256,2), (512,2), (512,2)], @@ -133,7 +130,7 @@ function VGG(imsize::Dims{2}; nclasses = nclasses, fcsize = fcsize, dropout = dropout) - + VGG(layers) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 1cae36b6e..78b729c01 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -45,7 +45,7 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu; push!(layers, BatchNorm(Int(bnplanes), activations.bn; initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum)) - return rev ? reverse(layers) : layers + return rev ? Chain(reverse(layers)) : Chain(layers) end """ @@ -82,13 +82,13 @@ depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu; initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1, stride = 1, kwargs...) = - vcat(conv_bn(kernelsize, inplanes, inplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, - ϵ = ϵ, momentum = momentum, - stride = stride, groups = Int(inplanes), kwargs...), - conv_bn((1, 1), inplanes, outplanes, activation; - rev = rev, initβ = initβ, initγ = initγ, - ϵ = ϵ, momentum = momentum)) + Chain(vcat(conv_bn(kernelsize, inplanes, inplanes, activation; + rev = rev, initβ = initβ, initγ = initγ, + ϵ = ϵ, momentum = momentum, + stride = stride, groups = Int(inplanes), kwargs...), + conv_bn((1, 1), inplanes, outplanes, activation; + rev = rev, initβ = initβ, initγ = initγ, + ϵ = ϵ, momentum = momentum))) """ skip_projection(inplanes, outplanes, downsample = false) @@ -101,9 +101,9 @@ Create a skip projection - `outplanes`: the number of output feature maps - `downsample`: set to `true` to downsample the input """ -skip_projection(inplanes, outplanes, downsample = false) = downsample ? - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) : - Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...) +skip_projection(inplanes, outplanes, downsample = false) = downsample ? + conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false) : + conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false) # array -> PaddedView(0, array, outplanes) for zero padding arrays """ @@ -144,8 +144,8 @@ Squeeze and excitation layer used by MobileNet variants function squeeze_excite(channels, reduction = 4) @assert (reduction >= 1) "`reduction` must be >= 1" SkipConnection(Chain(AdaptiveMeanPool((1, 1)), - conv_bn((1, 1), channels, channels ÷ reduction, relu; bias = false)..., - conv_bn((1, 1), channels ÷ reduction, channels, hardσ)...), .*) + conv_bn((1, 1), channels, channels ÷ reduction, relu; bias = false), + conv_bn((1, 1), channels ÷ reduction, channels, hardσ)), .*) end """ @@ -171,14 +171,14 @@ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activ @assert stride in [1, 2] "`stride` has to be 1 or 2" pad = @. (kernel_size - 1) ÷ 2 - conv1 = (inplanes == hidden_planes) ? () : conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false) + conv1 = (inplanes == hidden_planes) ? identity : conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false) selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes, reduction) - invres = Chain(conv1..., + invres = Chain(conv1, conv_bn(kernel_size, hidden_planes, hidden_planes, activation; - bias = false, stride, pad = pad, groups = hidden_planes)..., + bias = false, stride, pad = pad, groups = hidden_planes), selayer, - conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false)...) + conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false)) (stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres end diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index afc6d868d..135a1de86 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -13,7 +13,7 @@ patches. - `inchannels`: the number of channels in the input image - `patch_size`: the size of the patches - `embedplanes`: the number of channels in the embedding -- `norm_layer`: the normalization layer - by default the identity function but otherwise takes a +- `norm_layer`: the normalization layer - by default the identity function but otherwise takes a single argument constructor for a normalization layer like LayerNorm or BatchNorm - `flatten`: set true to flatten the input spatial dimensions after the embedding """ diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index a0c13d82b..f6e403134 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -56,8 +56,8 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, norm_laye npatches = prod(imsize .÷ patch_size) dp_rates = LinRange{Float32}(0., drop_path_rate, depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), - [block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) - for i in 1:depth]...) + Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) + for i in 1:depth])) classification_head = Chain(norm_layer(embedplanes), seconddimmean, Dense(embedplanes, nclasses)) return Chain(layers, classification_head) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index f49b42be2..55b3e3d30 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -17,7 +17,7 @@ function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0 SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); dropout)), +)) for _ in 1:depth] - Chain(layers...) + Chain(layers) end """ diff --git a/test/convnets.jl b/test/convnets.jl index 9470a975d..4c2c6026f 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -11,6 +11,8 @@ PRETRAINED_MODELS = [] @test_skip gradtest(model, rand(Float32, 256, 256, 3, 1)) end +GC.gc() + @testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] m = VGG(sz, batchnorm = bn) @@ -25,6 +27,8 @@ end end end +GC.gc() + @testset "ResNet" begin @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] m = ResNet(sz) @@ -47,6 +51,8 @@ end end end +GC.gc() + @testset "ResNeXt" begin @testset for depth in [50, 101, 152] m = ResNeXt(depth) @@ -61,6 +67,8 @@ end end end +GC.gc() + @testset "GoogLeNet" begin m = GoogLeNet() @test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1) @@ -68,6 +76,8 @@ end @test_skip gradtest(m, rand(Float32, 224, 224, 3, 1)) end +GC.gc() + @testset "Inception3" begin m = Inception3() @test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1) @@ -75,6 +85,8 @@ end @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2)) end +GC.gc() + @testset "SqueezeNet" begin m = SqueezeNet() @test size(m(rand(Float32, 224, 224, 3, 1))) == (1000, 1) @@ -147,7 +159,7 @@ end GC.gc() @testset "ConvNeXt" verbose = true begin - @testset for mode in [:tiny, :small, :base, :large, :xlarge] + @testset for mode in [:tiny, :small, :base, :large] #, :xlarge] @testset for drop_path_rate in [0.0, 0.5, 0.99] m = ConvNeXt(mode; drop_path_rate)