-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add split function #1237
add split function #1237
Conversation
nx/lib/nx.ex
Outdated
axis = | ||
cond do | ||
is_integer(axis) and axis in axis_values -> | ||
axis | ||
|
||
is_atom(axis) and axis in axis_names -> | ||
dimensions = Enum.zip(axis_names, axis_values) | ||
dimensions[axis] | ||
|
||
true -> | ||
raise ":axis is out of tensor bounds." | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe there's a Nx.Shape.normalize_axis function for doing this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't understand how to use normalize the simple way I would like.
nx/lib/nx.ex
Outdated
end | ||
|
||
values = elem(shape, axis) | ||
size = values - split |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when size is negative? Do we want to allow negative splitting (which counts from the back)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if we want this behavior.
688a8e0
to
0264f4d
Compare
@polvalente there's some weird test failing |
> | ||
""" | ||
@doc type: :indexed | ||
def split(tensor, split, opts \\ []) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tiagodavi @josevalim I've pushed a refactor to collapse both clauses into the same one, mostly to highlight that the slicing is the same, with the only change being how we calculate it.
I've also changed the examples to operate on a written-out iota instead of arbitrary values because it makes it easier to compare results.
We still need to discuss if we want to accept a negative integer split.
It's easy to do it by setting right after axis_size =
: split = if is_integer(split) and split < 0, do: axis_size + split, else: split
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though I think that negative split might be confusing in the sense that we don't know if the results will be reversed or swapped in the result tuple (either or both)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@polvalente we will base it on Enum.split
(which means looking up from the end).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I am fine with postponing this for now. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, so given this, my suggestion suffices:
iex(2)> Enum.split(1..10, 10 - 1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}
iex(3)> Enum.split(1..10, -1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wdyt?
Awesome, thank you folks. |
based on this: elixir-nx/scholar#107