Skip to content

Commit

Permalink
Fix conv transpose with channels last (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Dec 21, 2022
1 parent 477cff3 commit 901bf85
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
6 changes: 3 additions & 3 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,9 @@ defmodule Axon.Layers do

padding =
transform(
{Nx.shape(kernel), opts[:kernel_dilation], strides, opts[:padding]},
fn {shape, k_dilation, strides, padding} ->
Axon.Shape.conv_transpose_padding(shape, k_dilation, strides, padding)
{Nx.shape(kernel), opts[:kernel_dilation], strides, opts[:padding], opts[:channels]},
fn {shape, k_dilation, strides, padding, channels} ->
Axon.Shape.conv_transpose_padding(shape, k_dilation, strides, padding, channels)
end
)

Expand Down
18 changes: 13 additions & 5 deletions lib/axon/shape.ex
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,20 @@ defmodule Axon.Shape do
@doc """
Calculates the padding needed for a transposed convolution.
"""
def conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding)
def conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding, channels)
when padding in [:valid, :same] do
kernel_spatial_dims =
kernel_shape
|> Tuple.delete_at(0)
|> Tuple.delete_at(0)
case channels do
:first ->
kernel_shape
|> Tuple.delete_at(0)
|> Tuple.delete_at(0)

:last ->
kernel_shape
|> Tuple.delete_at(tuple_size(kernel_shape) - 1)
|> Tuple.delete_at(tuple_size(kernel_shape) - 2)
end

kernel_dilation =
if is_list(kernel_dilation),
Expand Down Expand Up @@ -387,7 +395,7 @@ defmodule Axon.Shape do
end
end

def conv_transpose_padding(_, _, _, padding), do: padding
def conv_transpose_padding(_, _, _, padding, _), do: padding

@doc """
Calculates the shape of a depthwise convolution kernel given the
Expand Down
13 changes: 13 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,19 @@ defmodule Axon.LayersTest do
end

describe "conv_transpose" do
test "channels first same as channels last" do
input = Nx.random_uniform({1, 1, 28, 28})
t_input = Nx.transpose(input, axes: [0, 2, 3, 1])
kernel = Nx.random_uniform({3, 1, 4, 4})
t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0])
bias = Nx.tensor(0.0)

first = Axon.Layers.conv_transpose(input, kernel, bias, channels: :first)
last = Axon.Layers.conv_transpose(t_input, t_kernel, bias, channels: :last)

assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2]))
end

test "correct valid padding, no strides" do
inp = Nx.iota({1, 1, 4}, type: {:f, 32})
kernel = Nx.iota({3, 1, 2}, type: {:f, 32})
Expand Down

0 comments on commit 901bf85

Please sign in to comment.