diff --git a/nx/.gitignore b/nx/.gitignore
index 1d7ddad4d5..fb5e8488ed 100644
--- a/nx/.gitignore
+++ b/nx/.gitignore
@@ -21,3 +21,6 @@ erl_crash.dump
 
 # Ignore package tarball (built via "mix hex.build").
 nx-*.tar
+
+# ASDF files
+.tool-versions
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index 376da82749..4cca64f9d2 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -3265,6 +3265,151 @@ defmodule Nx do
     end)
   end
 
+  @doc ~S"""
+  Splits a tensor into train and test subsets.
+
+  Returns a `{left, right}` tuple: `left` holds the first entries along
+  `:axis` and `right` holds the remaining ones.
+
+  `split` must be defined so that there are no empty result tensors.
+  This means that `split` must be:
+
+    * an integer such that `0 < split` and `split < axis_size`
+    * a float such that `0.0 < split` and `ceil(axis_size * split) < axis_size`
+
+  When `split` is an integer, it is the number of entries kept in `left`.
+  When it is a float, `left` keeps `ceil(axis_size * split)` entries.
+
+  An `ArgumentError` is raised if `split` is out of bounds or is neither
+  an integer nor a float.
+
+  ## Options
+
+    * `:axis` - The axis along which to split the tensor. Defaults to `0`.
+
+  ## Examples
+
+  All examples operate on the same tensor so that it's easier to compare
+  different configurations.
+
+      iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
+      iex> {left, right} = Nx.split(t, 2, axis: 0)
+      iex> left
+      #Nx.Tensor<
+        s64[2][4]
+        [
+          [0, 1, 2, 3],
+          [4, 5, 6, 7]
+        ]
+      >
+      iex> right
+      #Nx.Tensor<
+        s64[1][4]
+        [
+          [8, 9, 10, 11]
+        ]
+      >
+      iex> {left, right} = Nx.split(t, 2, axis: 1)
+      iex> left
+      #Nx.Tensor<
+        s64[3][2]
+        [
+          [0, 1],
+          [4, 5],
+          [8, 9]
+        ]
+      >
+      iex> right
+      #Nx.Tensor<
+        s64[3][2]
+        [
+          [2, 3],
+          [6, 7],
+          [10, 11]
+        ]
+      >
+
+      iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
+      iex> {left, right} = Nx.split(t, 0.5, axis: 0)
+      iex> left
+      #Nx.Tensor<
+        s64[2][4]
+        [
+          [0, 1, 2, 3],
+          [4, 5, 6, 7]
+        ]
+      >
+      iex> right
+      #Nx.Tensor<
+        s64[1][4]
+        [
+          [8, 9, 10, 11]
+        ]
+      >
+      iex> {left, right} = Nx.split(t, 0.75, axis: 1)
+      iex> left
+      #Nx.Tensor<
+        s64[3][3]
+        [
+          [0, 1, 2],
+          [4, 5, 6],
+          [8, 9, 10]
+        ]
+      >
+      iex> right
+      #Nx.Tensor<
+        s64[3][1]
+        [
+          [3],
+          [7],
+          [11]
+        ]
+      >
+  """
+  @doc type: :indexed
+  def split(tensor, split, opts \\ [])
+
+  def split(tensor, split, opts) do
+    tensor = to_tensor(tensor)
+    opts = keyword!(opts, axis: 0)
+    axis = Keyword.fetch!(opts, :axis)
+
+    axis = Nx.Shape.normalize_axis(tensor.shape, axis, tensor.names)
+    axis_size = axis_size(tensor, axis)
+
+    # The float index is computed lazily, inside the float branch, so that
+    # a non-numeric `split` reaches the final clause instead of failing with
+    # an ArithmeticError on the multiplication.
+    split_index =
+      cond do
+        is_integer(split) and split > 0 and split < axis_size ->
+          split
+
+        is_integer(split) ->
+          raise ArgumentError,
+                "split must be an integer greater than zero and less than the length of the given axis"
+
+        is_float(split) ->
+          float_split_index = Kernel.ceil(split * axis_size)
+
+          if float_split_index > 0 and float_split_index < axis_size do
+            float_split_index
+          else
+            raise ArgumentError,
+                  "split must be a float such that 0 < split and ceil(split * axis_size) < axis_size"
+          end
+
+        true ->
+          raise ArgumentError,
+                "invalid split received, expected a float or an integer, got: #{inspect(split)}"
+      end
+
+    {
+      slice_along_axis(tensor, 0, split_index, axis: axis),
+      slice_along_axis(tensor, split_index, axis_size - split_index, axis: axis)
+    }
+  end
+
   @doc """
   Broadcasts `tensor` to the given `broadcast_shape`.
 
diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs
index 19c88e0c6c..68ac9a4a98 100644
--- a/nx/test/nx_test.exs
+++ b/nx/test/nx_test.exs
@@ -3022,4 +3022,123 @@ defmodule NxTest do
       assert_all_close(zeros, Nx.imag(x_ifft), atol: 1.0e-8)
     end
   end
+
+  describe "split/3" do
+    test "split is less than zero" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "split must be an integer greater than zero and less than the length of the given axis",
+                   fn ->
+                     Nx.split(tensor, -1)
+                   end
+    end
+
+    test "split is a float out of bounds" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "split must be a float such that 0 < split and ceil(split * axis_size) < axis_size",
+                   fn ->
+                     Nx.split(tensor, 1.0)
+                   end
+    end
+
+    test "split is neither an integer nor a float" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "invalid split received, expected a float or an integer, got: :half",
+                   fn ->
+                     Nx.split(tensor, :half)
+                   end
+    end
+
+    test "split is greater than tensor length" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "split must be an integer greater than zero and less than the length of the given axis",
+                   fn ->
+                     Nx.split(tensor, 3, axis: 1)
+                   end
+    end
+
+    test "axis is out of tensor bounds" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "given axis (2) invalid for shape with rank 2",
+                   fn ->
+                     Nx.split(tensor, 2, axis: 2)
+                   end
+    end
+
+    test "named axis is invalid" do
+      tensor = Nx.iota({10, 2}, names: [:x, :y])
+
+      assert_raise ArgumentError,
+                   "name :z not found in tensor with names [:x, :y]",
+                   fn ->
+                     Nx.split(tensor, 2, axis: :z)
+                   end
+    end
+
+    test "split into 50% for training and 50% for testing with floats on columns" do
+      tensor = Nx.iota({4, 4}, names: [:rows, :columns])
+      {train, test} = Nx.split(tensor, 0.5, axis: :columns)
+
+      assert {4, 2} == Nx.shape(train)
+      assert {4, 2} == Nx.shape(test)
+    end
+
+    test "split with a float rounds the split index up" do
+      tensor = Nx.iota({10, 2})
+      {train, test} = Nx.split(tensor, 0.75)
+
+      assert {8, 2} == Nx.shape(train)
+      assert {2, 2} == Nx.shape(test)
+    end
+
+    test "split into 70% for training and 30% for testing along a named :axis" do
+      tensor = Nx.iota({100, 6}, names: [:rows, :columns])
+      {train, test} = Nx.split(tensor, 70, axis: :rows)
+
+      assert {70, 6} == Nx.shape(train)
+      assert {30, 6} == Nx.shape(test)
+    end
+
+    test "split into 90% for training and 10% for testing along a named :axis" do
+      tensor = Nx.iota({2, 100}, names: [:rows, :columns])
+      {train, test} = Nx.split(tensor, 90, axis: :columns)
+
+      assert {2, 90} == Nx.shape(train)
+      assert {2, 10} == Nx.shape(test)
+    end
+
+    test "split into 50% for training and 50% for testing along the :axis 1" do
+      tensor = Nx.iota({100, 10})
+      {train, test} = Nx.split(tensor, 5, axis: 1)
+
+      assert {100, 5} == Nx.shape(train)
+      assert {100, 5} == Nx.shape(test)
+    end
+
+    test "split into 61% for training and 39% for testing" do
+      tensor = Nx.iota({100, 10})
+      {train, test} = Nx.split(tensor, 61)
+
+      assert {61, 10} == Nx.shape(train)
+      assert {39, 10} == Nx.shape(test)
+    end
+
+    test "split into 60% for training and 40% for testing with unbalanced data" do
+      tensor = Nx.iota({99, 4})
+
+      {train, test} = Nx.split(tensor, 60)
+
+      assert {60, 4} == Nx.shape(train)
+      assert {39, 4} == Nx.shape(test)
+    end
+  end
 end