Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix resize methods with kernel #554

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2041,32 +2041,35 @@ defmodule Axon.Layers do
deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
for axis <- spatial_axes, reduce: input do
input ->
input_shape = Nx.shape(input)
input_size = elem(input_shape, axis)
output_size = elem(out_shape, axis)
resize_axis_with_kernel(input,
axis: axis,
output_size: elem(out_shape, axis),
kernel_fun: kernel_fun
)
end
end

inv_scale = input_size / output_size
kernel_scale = Nx.max(1, inv_scale)
defnp resize_axis_with_kernel(input, opts) do
axis = opts[:axis]
output_size = opts[:output_size]
kernel_fun = opts[:kernel_fun]

sample_f =
Nx.add(Nx.iota({1, output_size}), 0.5) |> Nx.multiply(Nx.subtract(inv_scale, 0.5))
input_size = Nx.axis_size(input, axis)

x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale)
weights = kernel_fun.(x)
inv_scale = input_size / output_size
kernel_scale = max(1, inv_scale)

weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
weights = kernel_fun.(x)

weights =
Nx.select(
Nx.greater(Nx.abs(weights), 1000 * @f32_eps),
safe_divide(weights, weights_sum),
0
)
weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)

input = Nx.dot(input, [axis], weights, [0])
# The transformed axis is moved to the end, so we transpose back
reorder_axis(input, -1, axis)
end
weights = Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0)

input = Nx.dot(input, [axis], weights, [0])
# The transformed axis is moved to the end, so we transpose back
reorder_axis(input, -1, axis)
end

defnp fill_linear_kernel(x) do
Expand Down
65 changes: 65 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,71 @@ defmodule Axon.LayersTest do
Axon.Layers.resize(inp)
end
end

# Adapted from NxImage
test "methods" do
# Reference values computed in jax

image = Nx.iota({1, 2, 2, 3}, type: :f32)

assert_equal(
Axon.Layers.resize(image, size: {3, 3}, method: :nearest),
Nx.tensor([
[
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]]
]
])
)

assert_equal(
Axon.Layers.resize(image, size: {3, 3}, method: :bilinear),
Nx.tensor([
[
[[0.0, 1.0, 2.0], [1.5, 2.5, 3.5], [3.0, 4.0, 5.0]],
[[3.0, 4.0, 5.0], [4.5, 5.5, 6.5], [6.0, 7.0, 8.0]],
[[6.0, 7.0, 8.0], [7.5, 8.5, 9.5], [9.0, 10.0, 11.0]]
]
])
)

assert_all_close(
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic),
Nx.tensor([
[
[[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]],
[[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]],
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
]
]),
atol: 1.0e-4
)

assert_all_close(
Axon.Layers.resize(image, size: {3, 3}, method: :lanczos3),
Nx.tensor([
[
[[-1.1173, -0.1173, 0.8827], [0.7551, 1.7551, 2.7551], [2.6276, 3.6276, 4.6276]],
[[2.6276, 3.6276, 4.6276], [4.5, 5.5, 6.5], [6.3724, 7.3724, 8.3724]],
[[6.3724, 7.3724, 8.3724], [8.2449, 9.2449, 10.2449], [10.1173, 11.1173, 12.1173]]
]
]),
atol: 1.0e-4
)

assert_all_close(
Axon.Layers.resize(image, size: {3, 3}, method: :lanczos5),
Nx.tensor([
[
[[-1.3525, -0.3525, 0.6475], [0.5984, 1.5984, 2.5984], [2.5492, 3.5492, 4.5492]],
[[2.5492, 3.5492, 4.5492], [4.5, 5.5, 6.5], [6.4508, 7.4508, 8.4508]],
[[6.4508, 7.4508, 8.4508], [8.4016, 9.4016, 10.4016], [10.3525, 11.3525, 12.3525]]
]
]),
atol: 1.0e-4
)
end
end

describe "lstm_cell" do
Expand Down
Loading