Labels
bug (Something isn't working), pytorch divergence (XLA behavior doesn't match Pytorch eager frontend)
Description
🐛 Bug
Hi, I was working on some PyTorch code designed to run on multiple backends, including TPU via XLA. The code calls torch.nan_to_num() with infinite replacement values, which works on PyTorch's CPU and GPU backends but not on TPU.
To Reproduce
import torch
import torch_xla.core.xla_model as xm
xla_device = xm.xla_device()
cpu_device = torch.device("cpu")
neg_inf = float("-inf")
pos_inf = float("inf")
xla_tensor = torch.zeros(3, 3, device=xla_device)
cpu_tensor = torch.zeros(3, 3, device=cpu_device)
try:
    # XLA/TPU path: raises a RuntimeError from an internal range check
    torch.nan_to_num(
        xla_tensor,
        nan=neg_inf,
        posinf=pos_inf,
        neginf=neg_inf
    )
except RuntimeError as e:
    print(f"Failed: {e}")

# CPU path: the same call succeeds
torch.nan_to_num(
    cpu_tensor,
    nan=neg_inf,
    posinf=pos_inf,
    neginf=neg_inf
)
print("Passed")

Steps to reproduce the behavior:
- Run the above code snippet.
This is the stack trace I get with the XLA backend (repeated CPython interpreter frames elided for brevity). The failing check rejects any replacement value outside the destination dtype's finite range, so ±inf always trips it:
[rank0]: RuntimeError: torch_xla/csrc/aten_xla_type.cpp:2346 : Check failed: min_max.min.toDouble() <= replacement.toDouble() && replacement.toDouble() <= min_max.max.toDouble()
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace()
[rank0]:        torch_xla::XLANativeFunctions::nan_to_num(at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        at::_ops::nan_to_num::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
[rank0]:        torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
[rank0]:        ... (CPython interpreter frames elided) ...
[rank0]: *** End stack trace ***
[rank0]: Type BFloat16 replacement value -inf must be in the range [-3.40282e+38, 3.40282e+38].
Expected behavior
torch.nan_to_num() should behave the same on all PyTorch backends: infinite replacement values that are accepted on CPU and GPU should also be accepted on XLA.
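For reference, a minimal sketch of what the CPU backend already does with these arguments (the example tensor and printed output are mine, not from the original report):

import torch

t = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0])
# NaN -> -inf, +inf -> +inf (kept), -inf -> -inf (kept), finite values untouched
print(torch.nan_to_num(t, nan=float("-inf"), posinf=float("inf"), neginf=float("-inf")))
# tensor([-inf, inf, -inf, 1.])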
Environment
- Reproducible on: TPU (XLA backend)
- torch_xla version: 2.6.0.dev20241126
Additional context
As a workaround, I can pass the dtype's finite extremes (torch.finfo(dtype).min / torch.finfo(dtype).max) in place of the infinities, but that is still not ideal.
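For illustration, a minimal sketch of that workaround (the helper name finite_nan_to_num is hypothetical):

import torch

def finite_nan_to_num(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: substitute the dtype's finite extremes for the
    # infinities so XLA's range check (dtype min <= replacement <= dtype max) passes.
    info = torch.finfo(t.dtype)
    return torch.nan_to_num(t, nan=info.min, posinf=info.max, neginf=info.min)

The result differs from the CPU/GPU output wherever a true ±inf was wanted, which is why this is only a stopgap.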