From 034764d3ae7e9450a729fa1dbd13f10dffcbede9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Fri, 16 Feb 2024 23:56:18 +0700 Subject: [PATCH 1/2] Fix resize methods with kernel --- lib/axon/layers.ex | 4 ++- test/axon/layers_test.exs | 65 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index 238c952c..d37e6968 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -2049,7 +2049,9 @@ defmodule Axon.Layers do kernel_scale = Nx.max(1, inv_scale) sample_f = - Nx.add(Nx.iota({1, output_size}), 0.5) |> Nx.multiply(Nx.subtract(inv_scale, 0.5)) + Nx.add(Nx.iota({1, output_size}), 0.5) + |> Nx.multiply(inv_scale) + |> Nx.subtract(0.5) x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale) weights = kernel_fun.(x) diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index 6ed1643c..3c9db59a 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -944,6 +944,71 @@ defmodule Axon.LayersTest do Axon.Layers.resize(inp) end end + + # Adapted from NxImage + test "methods" do + # Reference values computed in jax + + image = Nx.iota({1, 2, 2, 3}, type: :f32) + + assert_equal( + Axon.Layers.resize(image, size: {3, 3}, method: :nearest), + Nx.tensor([ + [ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]] + ] + ]) + ) + + assert_equal( + Axon.Layers.resize(image, size: {3, 3}, method: :bilinear), + Nx.tensor([ + [ + [[0.0, 1.0, 2.0], [1.5, 2.5, 3.5], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [4.5, 5.5, 6.5], [6.0, 7.0, 8.0]], + [[6.0, 7.0, 8.0], [7.5, 8.5, 9.5], [9.0, 10.0, 11.0]] + ] + ]) + ) + + assert_all_close( + Axon.Layers.resize(image, size: {3, 3}, method: :bicubic), + Nx.tensor([ + [ + [[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]], + [[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]], + [[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]] + ] + ]), + atol: 1.0e-4 + ) + + assert_all_close( + Axon.Layers.resize(image, size: {3, 3}, method: :lanczos3), + Nx.tensor([ + [ + [[-1.1173, -0.1173, 0.8827], [0.7551, 1.7551, 2.7551], [2.6276, 3.6276, 4.6276]], + [[2.6276, 3.6276, 4.6276], [4.5, 5.5, 6.5], [6.3724, 7.3724, 8.3724]], + [[6.3724, 7.3724, 8.3724], [8.2449, 9.2449, 10.2449], [10.1173, 11.1173, 12.1173]] + ] + ]), + atol: 1.0e-4 + ) + + assert_all_close( + Axon.Layers.resize(image, size: {3, 3}, method: :lanczos5), + Nx.tensor([ + [ + [[-1.3525, -0.3525, 0.6475], [0.5984, 1.5984, 2.5984], [2.5492, 3.5492, 4.5492]], + [[2.5492, 3.5492, 4.5492], [4.5, 5.5, 6.5], [6.4508, 7.4508, 8.4508]], + [[6.4508, 7.4508, 8.4508], [8.4016, 9.4016, 10.4016], [10.3525, 11.3525, 12.3525]] + ] + ]), + atol: 1.0e-4 + ) + end end describe "lstm_cell" do From b1f35e11b66672ffdcdc658fae01a04f109978ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Sat, 17 Feb 2024 00:13:49 +0700 Subject: [PATCH 2/2] Move more logic to defn --- lib/axon/layers.ex | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index d37e6968..facec674 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -2041,34 +2041,35 @@ defmodule Axon.Layers do deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do for axis <- spatial_axes, reduce: input do input -> - input_shape = Nx.shape(input) - input_size = elem(input_shape, axis) - output_size = elem(out_shape, axis) + resize_axis_with_kernel(input, + axis: axis, + output_size: elem(out_shape, axis), + kernel_fun: kernel_fun + ) + end + end - inv_scale = input_size / output_size - kernel_scale = Nx.max(1, inv_scale) + defnp resize_axis_with_kernel(input, opts) do + axis = opts[:axis] + output_size = opts[:output_size] + kernel_fun = opts[:kernel_fun] - sample_f = - Nx.add(Nx.iota({1, output_size}), 0.5) - |> Nx.multiply(inv_scale) - |> Nx.subtract(0.5) + input_size = Nx.axis_size(input, axis) - x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale) - weights = kernel_fun.(x) + inv_scale = input_size / output_size + kernel_scale = max(1, inv_scale) - weights_sum = Nx.sum(weights, axes: [0], keep_axes: true) + sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5 + x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale + weights = kernel_fun.(x) - weights = - Nx.select( - Nx.greater(Nx.abs(weights), 1000 * @f32_eps), - safe_divide(weights, weights_sum), - 0 - ) + weights_sum = Nx.sum(weights, axes: [0], keep_axes: true) - input = Nx.dot(input, [axis], weights, [0]) - # The transformed axis is moved to the end, so we transpose back - reorder_axis(input, -1, axis) - end + weights = Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0) + + input = Nx.dot(input, [axis], weights, [0]) + # The transformed axis is moved to the end, so we transpose back + reorder_axis(input, -1, axis) end defnp fill_linear_kernel(x) do