Default integers to 32-bit precision #1524
Conversation
Our torchx is also old. Those may be fixed if we update it.

One VM crash was related to a specific dot product with u32:

u = Nx.tensor([[[1]], [[2]]])
v = Nx.tensor([[[3]], [[4]]])
Nx.dot(u, [2], [0], v, [2], [0])

I changed the default libtorch version from 2.0.0 to 2.1.0, and that fixed it. The other crashes I fixed by casting in appropriate places.
defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)}
defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)}
@polvalente after upgrading to libtorch 2.1.0, I saw a bunch of warnings like this:
[W Resize.cpp:35] Warning: An output with one or more elements was resized since it had shape [], which does not match the required output shape [3]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function _resize_output_check)
An example where this happens is:
Nx.logical_and(Nx.tensor(1), Nx.tensor([1, 2, 3]))
The warning only happens if the first operand is a scalar. Doing the actual broadcasting removes the warning.
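As a sketch of the idea (not the PR's exact code), broadcasting the scalar operand to the output shape up-front means the backend is never handed a shape-[] tensor that it would have to resize:

```elixir
# Explicitly broadcast the scalar operand before the binary op.
a = Nx.tensor(1)          # shape {}
b = Nx.broadcast(a, {3})  # shape {3}, values [1, 1, 1]

# Both operands now already have the output shape, so libtorch does
# not need to resize a [] out-tensor and the warning goes away.
Nx.logical_and(b, Nx.tensor([1, 2, 3]))
```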
We do the broadcasting explicitly in StableHLO too, so that's fine I think.
I noticed that it was already the case before. Perhaps there's a way to catch it and raise an Elixir error instead, but that's not related to this PR.
@polvalente feel free to merge if you are ok with the Torchx changes :)
@@ -220,7 +220,7 @@ defmodule EXLA do
EXLA has a few intermediate tensors being built as s64 (see Value.eigh, for example). We should check whether those can be changed to s32 as well. Not a blocker for this PR.
I think those are unrelated; we just want to pass fixed integers as XLA inputs. We could actually make those unsigned (and change them to uint on the C++ side), since all of them are non-negative sizes. It doesn't matter much in this case, so your call!
I opened a PR (#1526) to change them to u64.
@@ -65,7 +65,7 @@ defmodule Torchx.Nx.RandomTest do
   # Output does not match Nx because of the sign of the remainder.
   distribution_case(:randint_split,
     args: [0, 10, [shape: {5}]],
-    expected: Nx.tensor([3, 2, 6, 0, 0], type: :s64)
+    expected: Nx.tensor([1, 1, 4, 1, 9], type: :s64)
I didn't really understand why these values changed, but that's ok.
It's because in args: [0, 10, [shape: {5}]], the numbers 0 and 10 passed to defn are now sent as s32 rather than s64. We could maintain the old behaviour with args: [Nx.s64(0), Nx.s64(10), [shape: {5}]], but changing the test is equally fine.
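To illustrate (assuming the 32-bit default this PR introduces): a bare integer literal is now inferred as s32, while the Nx.s64/1 wrapper forces the previous 64-bit type:

```elixir
# Under the new default, bare literals become 32-bit integers:
Nx.type(Nx.tensor(10))  # {:s, 32} with 32-bit defaults

# Wrapping the literal keeps the previous 64-bit behaviour:
Nx.type(Nx.s64(10))     # {:s, 64}
```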
Nx and EXLA pass, but there are Torchx segfaults that I need to debug.