Skip to content

Commit

Permalink
improve again
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagodavi committed Jun 2, 2023
1 parent 3f83739 commit 92eb87d
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 51 deletions.
91 changes: 81 additions & 10 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3268,7 +3268,7 @@ defmodule Nx do
@doc ~S"""
Split a tensor into train and test subsets.
`split` must be a float number between `0.0` and `1.0`.
`split` split must be an integer greater than zero and less than the length of the tensor.
## Options
Expand All @@ -3278,7 +3278,7 @@ defmodule Nx do
Split a tensor into two separate tensors.
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 0.8, axis: 0)
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 0)
iex> train
#Nx.Tensor<
s64[2][3]
Expand All @@ -3294,6 +3294,63 @@ defmodule Nx do
[23, 4, 1]
]
>
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 1)
iex> train
#Nx.Tensor<
s64[3][2]
[
[3, 6],
[26, 75],
[23, 4]
]
>
iex> test
#Nx.Tensor<
s64[3][1]
[
[5],
[3],
[1]
]
>
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :columns)
iex> train
#Nx.Tensor<
s64[rows: 3][columns: 2]
[
[3, 6],
[26, 75],
[23, 4]
]
>
iex> test
#Nx.Tensor<
s64[rows: 3][columns: 2]
[
[5, 20],
[3, 9],
[1, 5]
]
>
iex>{train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :rows)
iex> train
#Nx.Tensor<
s64[rows: 2][columns: 4]
[
[3, 6, 5, 20],
[26, 75, 3, 9]
]
>
iex> test
#Nx.Tensor<
s64[rows: 1][columns: 4]
[
[23, 4, 1, 5]
]
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])
Expand All @@ -3302,18 +3359,32 @@ defmodule Nx do
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

if is_float(split) and split > 0.0 and split < 1.0 do
rows = elem(shape, 0)
split_size = Kernel.floor(split * rows)
split_size = if split_size < 1, do: 1, else: split_size
remaining_size = rows - split_size
if is_integer(split) and split > 0 do
axis_values = axes(tensor)
axis_names = names(tensor)

axis =
cond do
is_integer(axis) and axis in axis_values ->
axis

is_atom(axis) and axis in axis_names ->
dimensions = Enum.zip(axis_names, axis_values)
dimensions[axis]

true ->
raise ":axis is out of tensor bounds."
end

values = elem(shape, axis)
size = values - split

{
slice_along_axis(tensor, 0, split_size, axis: axis),
slice_along_axis(tensor, split_size, remaining_size, axis: axis)
slice_along_axis(tensor, 0, split, axis: axis),
slice_along_axis(tensor, split, size, axis: axis)
}
else
raise ":split must be a float number between 0.0 and 1.0"
raise "split must be an integer greater than zero and less than the length of the tensor."
end
end

Expand Down
104 changes: 63 additions & 41 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2961,63 +2961,85 @@ defmodule NxTest do
end

describe "split/2" do
test "Split list into 50% for training and 50% for testing" do
test "split is less than zero" do
tensor = Nx.iota({10, 2}, names: [:x, :y])
{train, test} = Nx.split(tensor, 0.5)

assert Nx.tensor(
[
[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]
],
names: [:x, :y]
) == train

assert Nx.tensor(
[
[10, 11],
[12, 13],
[14, 15],
[16, 17],
[18, 19]
],
names: [:x, :y]
) == test
assert_raise RuntimeError,
"split must be an integer greater than zero and less than the length of the tensor.",
fn ->
Nx.split(tensor, -1)
end
end

test "split is greater than tensor length" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise ArgumentError,
"length at axis 1 must be less than axis size of 2, got: 3",
fn ->
Nx.split(tensor, 3, axis: 1)
end
end

test "axis is out of tensor bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
":axis is out of tensor bounds.",
fn ->
Nx.split(tensor, 2, axis: 2)
end
end

test "named axis is out of tensor bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
":axis is out of tensor bounds.",
fn ->
Nx.split(tensor, 2, axis: :z)
end
end

test "Split into 70% for training and 30% for testing" do
tensor = Nx.iota({100, 6})
{train, test} = Nx.split(tensor, 0.7)
test "split into 70% for training and 30% for testing along a named :axis" do
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 70, axis: :rows)

assert length(Nx.to_list(train)) == 70
assert length(Nx.to_list(test)) == 30
assert {70, 6} == Nx.shape(train)
assert {30, 6} == Nx.shape(test)
end

test "Split into 75% for training and 25% for testing" do
test "split into 90% for training and 10% for testing along a named :axis" do
tensor = Nx.iota({2, 100}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 90, axis: :columns)

assert {2, 90} == Nx.shape(train)
assert {2, 10} == Nx.shape(test)
end

test "split into 50% for training and 50% for testing along the :axis 1" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 0.75)
{train, test} = Nx.split(tensor, 5, axis: 1)

assert length(Nx.to_list(train)) == 75
assert length(Nx.to_list(test)) == 25
assert {100, 5} == Nx.shape(train)
assert {100, 5} == Nx.shape(test)
end

test "Split into 61% for training and 39% for testing" do
test "split into 61% for training and 39% for testing" do
tensor = Nx.iota({100, 10})
{train, test} = Nx.split(tensor, 0.61)
{train, test} = Nx.split(tensor, 61)

assert length(Nx.to_list(train)) == 61
assert length(Nx.to_list(test)) == 39
assert {61, 10} == Nx.shape(train)
assert {39, 10} == Nx.shape(test)
end

test "Split into 60% for training and 40% for testing with unbalanced data" do
tensor = Nx.iota({73, 4})
{train, test} = Nx.split(tensor, 0.61)
test "split into 60% for training and 40% for testing with unbalanced data" do
tensor = Nx.iota({99, 4})

{train, test} = Nx.split(tensor, 60)

assert length(Nx.to_list(train)) == 44
assert length(Nx.to_list(test)) == 29
assert {60, 4} == Nx.shape(train)
assert {39, 4} == Nx.shape(test)
end
end
end

0 comments on commit 92eb87d

Please sign in to comment.