Skip to content

Commit

Permalink
add both integer and float implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagodavi committed Jun 13, 2023
1 parent 92eb87d commit 688a8e0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 45 deletions.
102 changes: 57 additions & 45 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3268,7 +3268,7 @@ defmodule Nx do
@doc ~S"""
Split a tensor into train and test subsets.
`split` split must be an integer greater than zero and less than the length of the tensor.
`split` split must be either an integer greater than zero and less than the length of the tensor or a float number between `0.0` and `1.0`.
## Options
Expand Down Expand Up @@ -3315,67 +3315,33 @@ defmodule Nx do
]
>
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :columns)
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 9], [26, 75, 3, 20], [23, 4, 1, 56], [40, 6, 78, 94]]), 0.5, axis: 0)
iex> train
#Nx.Tensor<
s64[rows: 3][columns: 2]
s64[2][4]
[
[3, 6],
[26, 75],
[23, 4]
]
>
iex> test
#Nx.Tensor<
s64[rows: 3][columns: 2]
[
[5, 20],
[3, 9],
[1, 5]
]
>
iex>{train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :rows)
iex> train
#Nx.Tensor<
s64[rows: 2][columns: 4]
[
[3, 6, 5, 20],
[26, 75, 3, 9]
[3, 6, 5, 9],
[26, 75, 3, 20]
]
>
iex> test
#Nx.Tensor<
s64[rows: 1][columns: 4]
s64[2][4]
[
[23, 4, 1, 5]
[23, 4, 1, 56],
[40, 6, 78, 94]
]
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])

def split(%T{shape: shape} = tensor, split, opts) do
def split(%T{shape: shape} = tensor, split, opts) when is_integer(split) do
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

if is_integer(split) and split > 0 do
axis_values = axes(tensor)
axis_names = names(tensor)

axis =
cond do
is_integer(axis) and axis in axis_values ->
axis

is_atom(axis) and axis in axis_names ->
dimensions = Enum.zip(axis_names, axis_values)
dimensions[axis]

true ->
raise ":axis is out of tensor bounds."
end

if split > 0 do
axis = find_axis(tensor, axis)
values = elem(shape, axis)
size = values - split

Expand All @@ -3388,6 +3354,35 @@ defmodule Nx do
end
end

def split(%T{shape: shape} = tensor, split, opts) when is_float(split) do
opts = keyword!(opts, axis: 0)
axis = Keyword.fetch!(opts, :axis)

if split > 0.0 and split < 1.0 do
axis = find_axis(tensor, axis)

values = elem(shape, axis)

split_size = Kernel.ceil(split * values)

split_size =
cond do
split_size < 1 -> 1
split_size >= values -> 1
true -> split_size
end

remaining_size = values - split_size

{
slice_along_axis(tensor, 0, split_size, axis: axis),
slice_along_axis(tensor, split_size, remaining_size, axis: axis)
}
else
raise "split must be a float number between 0.0 and 1.0."
end
end

@doc """
Broadcasts `tensor` to the given `broadcast_shape`.
Expand Down Expand Up @@ -16127,4 +16122,21 @@ defmodule Nx do
end)
end
end

defp find_axis(tensor, axis) do
axis_values = axes(tensor)
axis_names = names(tensor)

cond do
is_integer(axis) and axis in axis_values ->
axis

is_atom(axis) and axis in axis_names ->
dimensions = Enum.zip(axis_names, axis_values)
dimensions[axis]

true ->
raise ":axis is out of tensor bounds."
end
end
end
18 changes: 18 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2971,6 +2971,16 @@ defmodule NxTest do
end
end

test "split is a float out of bounds" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

assert_raise RuntimeError,
"split must be a float number between 0.0 and 1.0.",
fn ->
Nx.split(tensor, 1.0)
end
end

test "split is greater than tensor length" do
tensor = Nx.iota({10, 2}, names: [:x, :y])

Expand Down Expand Up @@ -3001,6 +3011,14 @@ defmodule NxTest do
end
end

test "split into 50% for training and 50% for testing with floats on columns" do
tensor = Nx.iota({4, 4}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 0.5, axis: :columns)

assert {4, 2} == Nx.shape(train)
assert {4, 2} == Nx.shape(test)
end

test "split into 70% for training and 30% for testing along a named :axis" do
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
{train, test} = Nx.split(tensor, 70, axis: :rows)
Expand Down

0 comments on commit 688a8e0

Please sign in to comment.