Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split function #1237

Merged
merged 9 commits into from
Jun 14, 2023
Merged

Conversation

tiagodavi
Copy link
Contributor

based on this: elixir-nx/scholar#107

nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/test/nx_test.exs Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
@tiagodavi tiagodavi requested a review from polvalente June 2, 2023 14:44
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated
Comment on lines 3366 to 3377
axis =
cond do
is_integer(axis) and axis in axis_values ->
axis

is_atom(axis) and axis in axis_names ->
dimensions = Enum.zip(axis_names, axis_values)
dimensions[axis]

true ->
raise ":axis is out of tensor bounds."
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe there's a Nx.Shape.normalize_axis function for doing this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't understand how to use normalize the simple way I would like.

nx/lib/nx.ex Outdated
end

values = elem(shape, axis)
size = values - split
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when size is negative? Do we want to allow negative splitting (which counts from the back)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we want this behavior.

@tiagodavi tiagodavi force-pushed the td/add-split-function branch from 688a8e0 to 0264f4d Compare June 13, 2023 20:19
@tiagodavi
Copy link
Contributor Author

@polvalente there's some weird test failing
test/nx/serving_test.exs:637

nx/lib/nx.ex Outdated Show resolved Hide resolved
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiagodavi @josevalim I've pushed a refactor to collapse both clauses into the same one, mostly to highlight that the slicing is the same, with the only change being how we calculate it.

I've also changed the examples to operate on a written-out iota instead of arbitrary values because it makes it easier to compare results.

We still need to discuss if we want to accept a negative integer split.

It's easy to do it by setting right after axis_size = : split = if is_integer(split) and split < 0, do: axis_size + split, else: split

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though I think that negative split might be confusing in the sense that we don't know if the results will be reversed or swapped in the result tuple (either or both)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@polvalente we will base it on Enum.split (which means looking up from the end).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I am fine with postponing this for now. :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so given this, my suggestion suffices:

iex(2)> Enum.split(1..10, 10 - 1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}
iex(3)> Enum.split(1..10, -1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdyt?

@polvalente polvalente merged commit 31e3b2c into elixir-nx:main Jun 14, 2023
@tiagodavi tiagodavi deleted the td/add-split-function branch June 14, 2023 11:58
@tiagodavi
Copy link
Contributor Author

Awesome, thank you folks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants