add split function #1237

Merged (9 commits, Jun 14, 2023)
Changes from 6 commits
3 changes: 3 additions & 0 deletions nx/.gitignore
@@ -21,3 +21,6 @@ erl_crash.dump

# Ignore package tarball (built via "mix hex.build").
nx-*.tar

# ASDF files
.tool-versions
135 changes: 135 additions & 0 deletions nx/lib/nx.ex
@@ -3265,6 +3265,124 @@ defmodule Nx do
end)
end

@doc ~S"""
Splits a tensor into train and test subsets.

`split` must be either an integer greater than zero and less than the length of the tensor along the given axis, or a float between `0.0` and `1.0`.

## Options

* `:axis` - The axis along which to split the tensor. May be an integer or a named axis. Defaults to `0`.

## Examples

Split a tensor into two separate tensors.

iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 0)
iex> train
#Nx.Tensor<
s64[2][3]
[
[3, 6, 5],
[26, 75, 3]
]
>
iex> test
#Nx.Tensor<
s64[1][3]
[
[23, 4, 1]
]
>

iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 1)
iex> train
#Nx.Tensor<
s64[3][2]
[
[3, 6],
[26, 75],
[23, 4]
]
>
iex> test
#Nx.Tensor<
s64[3][1]
[
[5],
[3],
[1]
]
>

iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 9], [26, 75, 3, 20], [23, 4, 1, 56], [40, 6, 78, 94]]), 0.5, axis: 0)
iex> train
#Nx.Tensor<
s64[2][4]
[
[3, 6, 5, 9],
[26, 75, 3, 20]
]
>
iex> test
#Nx.Tensor<
s64[2][4]
[
[23, 4, 1, 56],
[40, 6, 78, 94]
]
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])
Contributor:
@tiagodavi @josevalim I've pushed a refactor to collapse both clauses into the same one, mostly to highlight that the slicing is the same, with the only change being how we calculate it.

I've also changed the examples to operate on a written-out iota instead of arbitrary values because it makes it easier to compare results.

We still need to discuss if we want to accept a negative integer split.

It's easy to do by adding, right after `axis_size` is computed: `split = if is_integer(split) and split < 0, do: axis_size + split, else: split`

Contributor:

Though I think that negative split might be confusing in the sense that we don't know if the results will be reversed or swapped in the result tuple (either or both)

Collaborator:

@polvalente we will base it on Enum.split (which means looking up from the end).

Collaborator:

But I am fine with postponing this for now. :)

Contributor:

Right, so given this, my suggestion suffices:

iex(2)> Enum.split(1..10, 10 - 1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}
iex(3)> Enum.split(1..10, -1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}

(Here `'\n'` is how the charlist `[10]` is printed.)

Contributor:

wdyt?
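
For illustration only, a minimal sketch of the normalization being discussed (it simply restates the suggestion above; `axis_size` is assumed to already hold `elem(shape, axis)`, and this code is not part of the PR):

# Hypothetical pre-processing step: a negative integer split counts
# from the end of the axis, mirroring Enum.split/2.
split =
  if is_integer(split) and split < 0 do
    axis_size + split
  else
    split
  end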


def split(%T{shape: shape} = tensor, split, opts) when is_integer(split) do
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

if split > 0 do
axis = find_axis(tensor, axis)
values = elem(shape, axis)
size = values - split
Collaborator:

What happens when size is negative? Do we want to allow negative splitting (which counts from the back)?

Contributor (PR author):

not sure if we want this behavior.


{
slice_along_axis(tensor, 0, split, axis: axis),
slice_along_axis(tensor, split, size, axis: axis)
}
else
raise "split must be an integer greater than zero and less than the length of the tensor."
end
end

def split(%T{shape: shape} = tensor, split, opts) when is_float(split) do
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

if split > 0.0 and split < 1.0 do
axis = find_axis(tensor, axis)

values = elem(shape, axis)

split_size = Kernel.ceil(split * values)

split_size =
cond do
split_size < 1 -> 1
split_size >= values -> 1
true -> split_size
end

remaining_size = values - split_size

{
slice_along_axis(tensor, 0, split_size, axis: axis),
slice_along_axis(tensor, split_size, remaining_size, axis: axis)
}
else
raise "split must be a float number between 0.0 and 1.0."
end
end
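# Worked example (illustrative note, not part of this diff): for a
# {4, 4} tensor with split = 0.5 on axis 0, values = 4, so
# split_size = Kernel.ceil(0.5 * 4) = 2; the clamping cond leaves it
# at 2 and remaining_size = 4 - 2 = 2, yielding two {2, 4} tensors,
# as in the 0.5 doctest above.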

@doc """
Broadcasts `tensor` to the given `broadcast_shape`.

@@ -16061,4 +16179,21 @@ defmodule Nx do
end)
end
end

defp find_axis(tensor, axis) do
axis_values = axes(tensor)
axis_names = names(tensor)

cond do
is_integer(axis) and axis in axis_values ->
axis

is_atom(axis) and axis in axis_names ->
dimensions = Enum.zip(axis_names, axis_values)
dimensions[axis]

true ->
raise ":axis is out of tensor bounds."
end
end
end
101 changes: 101 additions & 0 deletions nx/test/nx_test.exs
@@ -3022,4 +3022,105 @@ defmodule NxTest do
assert_all_close(zeros, Nx.imag(x_ifft), atol: 1.0e-8)
end
end

describe "split/2" do
test "split is less than zero" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
"split must be an integer greater than zero and less than the length of the tensor.",
fn ->
Nx.split(tensor, -1)
end
end

test "split is a float out of bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
"split must be a float number between 0.0 and 1.0.",
fn ->
Nx.split(tensor, 1.0)
end
end

test "split is greater than tensor length" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"length at axis 1 must be less than axis size of 2, got: 3",
fn ->
Nx.split(tensor, 3, axis: 1)
end
end

test "axis is out of tensor bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
":axis is out of tensor bounds.",
fn ->
Nx.split(tensor, 2, axis: 2)
end
end

test "named axis is out of tensor bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
":axis is out of tensor bounds.",
fn ->
Nx.split(tensor, 2, axis: :z)
end
end

test "split into 50% for training and 50% for testing with floats on columns" do
tensor = Nx.iota({4, 4}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 0.5, axis: :columns)

assert {4, 2} == Nx.shape(train)
assert {4, 2} == Nx.shape(test)
end

test "split into 70% for training and 30% for testing along a named :axis" do
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 70, axis: :rows)

assert {70, 6} == Nx.shape(train)
assert {30, 6} == Nx.shape(test)
end

test "split into 90% for training and 10% for testing along a named :axis" do
tensor = Nx.iota({2, 100}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 90, axis: :columns)

assert {2, 90} == Nx.shape(train)
assert {2, 10} == Nx.shape(test)
end

test "split into 50% for training and 50% for testing along the :axis 1" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 5, axis: 1)

assert {100, 5} == Nx.shape(train)
assert {100, 5} == Nx.shape(test)
end

test "split into 61% for training and 39% for testing" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 61)

assert {61, 10} == Nx.shape(train)
assert {39, 10} == Nx.shape(test)
end

test "split into 60% for training and 40% for testing with unbalanced data" do
tensor = Nx.iota({99, 4})

{train, test} = Nx.split(tensor, 60)

assert {60, 4} == Nx.shape(train)
assert {39, 4} == Nx.shape(test)
end
end
end
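
For reference, a hedged usage sketch of splitting along a named axis, mirroring the "50% for training and 50% for testing with floats on columns" test above (only the resulting shapes are shown; the printed tensor output is omitted):

iex> tensor = Nx.iota({4, 4}, names: [:rows, :columns])
iex> {train, test} = Nx.split(tensor, 0.5, axis: :columns)
iex> Nx.shape(train)
{4, 2}
iex> Nx.shape(test)
{4, 2}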