Skip to content

Commit

Permalink
add split function (#1237)
Browse files Browse the repository at this point in the history
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-authored-by: José Valim <jose.valim@gmail.com>
  • Loading branch information
3 people authored Jun 14, 2023
1 parent 13213c4 commit 31e3b2c
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
3 changes: 3 additions & 0 deletions nx/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ erl_crash.dump

# Ignore package tarball (built via "mix hex.build").
nx-*.tar

# ASDF files
.tool-versions
132 changes: 132 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3265,6 +3265,138 @@ defmodule Nx do
end)
end

@doc ~S"""
Split a tensor into train and test subsets.
`split` must be defined so that there are no empty result tensors.
This means that `split` must be:
* an integer such that `0 < split` and `split < axis_size`
* a float such that `0.0 < split` and `ceil(axis_size * split) < axis_size`
## Options
* `:axis` - The axis along which to split the tensor. Defaults to `0`.
## Examples
All examples will operate on the same tensor so that it's easier to compare different configurations.
iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
iex> {left, right} = Nx.split(t, 2, axis: 0)
iex> left
#Nx.Tensor<
s64[2][4]
[
[0, 1, 2, 3],
[4, 5, 6, 7]
]
>
iex> right
#Nx.Tensor<
s64[1][4]
[
[8, 9, 10, 11]
]
>
iex> {left, right} = Nx.split(t, 2, axis: 1)
iex> left
#Nx.Tensor<
s64[3][2]
[
[0, 1],
[4, 5],
[8, 9]
]
>
iex> right
#Nx.Tensor<
s64[3][2]
[
[2, 3],
[6, 7],
[10, 11]
]
>
iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
iex> {left, right} = Nx.split(t, 0.5, axis: 0)
iex> left
#Nx.Tensor<
s64[2][4]
[
[0, 1, 2, 3],
[4, 5, 6, 7]
]
>
iex> right
#Nx.Tensor<
s64[1][4]
[
[8, 9, 10, 11]
]
>
iex> {left, right} = Nx.split(t, 0.75, axis: 1)
iex> left
#Nx.Tensor<
s64[3][3]
[
[0, 1, 2],
[4, 5, 6],
[8, 9, 10]
]
>
iex> right
#Nx.Tensor<
s64[3][1]
[
[3],
[7],
[11]
]
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])

def split(tensor, split, opts) do
tensor = to_tensor(tensor)
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

axis = Nx.Shape.normalize_axis(tensor.shape, axis, tensor.names)
axis_size = axis_size(tensor, axis)

# only used in case the split is a float
float_split_index = Kernel.ceil(split * axis_size)

{split_index, remainder_length} =
cond do
is_integer(split) and split > 0 and split < axis_size ->
{split, axis_size - split}

is_integer(split) ->
raise ArgumentError,
"split must be an integer greater than zero and less than the length of the given axis"

is_float(split) and float_split_index > 0 and float_split_index < axis_size ->
{float_split_index, axis_size - float_split_index}

is_float(split) ->
raise ArgumentError,
"split must be a float such that 0 < split and ceil(split * axis_size) < 1"

true ->
raise ArgumentError,
"invalid split received, expected a float or an integer, got: #{inspect(split)}"
end

{
slice_along_axis(tensor, 0, split_index, axis: axis),
slice_along_axis(tensor, split_index, remainder_length, axis: axis)
}
end

@doc """
Broadcasts `tensor` to the given `broadcast_shape`.
Expand Down
101 changes: 101 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3022,4 +3022,105 @@ defmodule NxTest do
assert_all_close(zeros, Nx.imag(x_ifft), atol: 1.0e-8)
end
end

describe "split/2" do
test "split is less than zero" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"split must be an integer greater than zero and less than the length of the given axis",
fn ->
Nx.split(tensor, -1)
end
end

test "split is a float out of bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"split must be a float such that 0 < split and ceil(split * axis_size) < 1",
fn ->
Nx.split(tensor, 1.0)
end
end

test "split is greater than tensor length" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"split must be an integer greater than zero and less than the length of the given axis",
fn ->
Nx.split(tensor, 3, axis: 1)
end
end

test "axis is out of tensor bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"given axis (2) invalid for shape with rank 2",
fn ->
Nx.split(tensor, 2, axis: 2)
end
end

test "named axis is invalid" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"name :z not found in tensor with names [:x, :y]",
fn ->
Nx.split(tensor, 2, axis: :z)
end
end

test "split into 50% for training and 50% for testing with floats on columns" do
tensor = Nx.iota({4, 4}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 0.5, axis: :columns)

assert {4, 2} == Nx.shape(train)
assert {4, 2} == Nx.shape(test)
end

test "split into 70% for training and 30% for testing along a named :axis" do
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 70, axis: :rows)

assert {70, 6} == Nx.shape(train)
assert {30, 6} == Nx.shape(test)
end

test "split into 90% for training and 10% for testing along a named :axis" do
tensor = Nx.iota({2, 100}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 90, axis: :columns)

assert {2, 90} == Nx.shape(train)
assert {2, 10} == Nx.shape(test)
end

test "split into 50% for training and 50% for testing along the :axis 1" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 5, axis: 1)

assert {100, 5} == Nx.shape(train)
assert {100, 5} == Nx.shape(test)
end

test "split into 61% for training and 39% for testing" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 61)

assert {61, 10} == Nx.shape(train)
assert {39, 10} == Nx.shape(test)
end

test "split into 60% for training and 40% for testing with unbalanced data" do
tensor = Nx.iota({99, 4})

{train, test} = Nx.split(tensor, 60)

assert {60, 4} == Nx.shape(train)
assert {39, 4} == Nx.shape(test)
end
end
end

0 comments on commit 31e3b2c

Please sign in to comment.