
Default integers to 32-bit precision #1524

Merged — 2 commits merged into main from jk-s32 on Sep 3, 2024

Conversation

jonatanklosko (Member)

Nx and EXLA pass, but there are Torchx segfaults that I need to debug.

@josevalim (Collaborator)

Our torchx is also old. Those may be fixed if we update it.

@jonatanklosko (Member Author)

One VM crash was related to a specific dot product with u32:

u = Nx.tensor([[[1]], [[2]]])
v = Nx.tensor([[[3]], [[4]]])
Nx.dot(u, [2], [0], v, [2], [0])

I changed the default libtorch version from 2.0.0 to 2.1.0, and it's fixed.

The other crashes I fixed by casting in appropriate places.
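
As an illustration of that kind of casting fix (a sketch only, not the actual PR change; it assumes an operation that misbehaves for u32 on the Torchx backend), the input can be upcast with Nx.as_type/2 before the call and the result cast back afterwards:

t = Nx.tensor([1, 2, 3], type: :u32)

t
|> Nx.as_type({:s, 64})   # upcast so libtorch sees a natively supported type
|> Nx.sum()
|> Nx.as_type({:u, 32})   # restore the original type for the caller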

Comment on lines -802 to -803
defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)}
defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)}
jonatanklosko (Member Author)

@polvalente after upgrading to libtorch 2.1.0, I saw a bunch of warnings like this:

[W Resize.cpp:35] Warning: An output with one or more elements was resized since it had shape [], which does not match the required output shape [3]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function _resize_output_check)

An example where this happens is:

Nx.logical_and(Nx.tensor(1), Nx.tensor([1, 2, 3]))

The warning only happens if the first operand is a scalar. Doing the actual broadcasting removes the warning.
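
For reference, a minimal sketch of what doing the broadcasting explicitly looks like (illustrative only; the actual change lives in maybe_broadcast_bin_args/3 above): the scalar operand is broadcast to the other operand's shape before the binary op, so libtorch never has to resize an output with shape []:

l = Nx.tensor(1)
r = Nx.tensor([1, 2, 3])
Nx.logical_and(Nx.broadcast(l, Nx.shape(r)), r)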

Contributor

We do the broadcasting explicitly in StableHLO too, so that's fine I think.

@jonatanklosko (Member Author)

I noticed that Nx.tensor(0xFFFFFFFF) now crashes the VM with:

libc++abi: terminating due to uncaught exception of type std::runtime_error: value cannot be converted to type int without overflow

This was already the case with Nx.s32(0xFFFFFFFF); it's just that now it's more likely to happen by default.

Perhaps there's a way to catch it and raise an Elixir error instead, but that's not related to this PR.
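
For context (an illustrative note, not part of the PR): 0xFFFFFFFF is 4_294_967_295, which is above the s32 maximum of 2_147_483_647, so the conversion overflows; requesting a wider or unsigned type explicitly avoids the crash:

Nx.tensor(0xFFFFFFFF, type: :u32)
Nx.tensor(0xFFFFFFFF, type: :s64)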

@jonatanklosko (Member Author)

@polvalente feel free to merge, if you are ok with the torchx changes :)

@@ -220,7 +220,7 @@ defmodule EXLA do

Contributor

EXLA has a few intermediate tensors being built as s64 (see Value.eigh for example). We should also check if those can be changed to s32 as well. Not a blocker for this PR.

@jonatanklosko (Member Author), Sep 4, 2024

I think those are unrelated; there we just want to pass fixed integers as XLA inputs. We could actually make those unsigned (and change to uint on the C++ side), since all of those are non-negative sizes. Doesn't matter much in this case, your call!

jonatanklosko (Member Author)

I opened a PR to change those to u64: #1526.

@@ -65,7 +65,7 @@ defmodule Torchx.Nx.RandomTest do
   # Output does not match Nx because of the sign of the remainder.
   distribution_case(:randint_split,
     args: [0, 10, [shape: {5}]],
-    expected: Nx.tensor([3, 2, 6, 0, 0], type: :s64)
+    expected: Nx.tensor([1, 1, 4, 1, 9], type: :s64)
Contributor

I didn't really understand why these values changed, but that's ok.

jonatanklosko (Member Author)

It's because in args: [0, 10, [shape: {5}]] passed to defn, the numbers 0 and 10 are now passed as s32 rather than s64. We could maintain the old behaviour with args: [Nx.s64(0), Nx.s64(10), [shape: {5}]], but changing the test is equally fine.
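
A quick sketch of the default change being described (assuming Nx.type/1; the output comments are illustrative):

Nx.type(Nx.tensor(10))
#=> {:s, 32} with this PR (previously {:s, 64})

Nx.type(Nx.s64(10))
#=> {:s, 64}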

@polvalente merged commit 3ef1c7a into main on Sep 3, 2024 (8 checks passed)
@polvalente deleted the jk-s32 branch on September 3, 2024 at 20:19