torch.nan_to_num doesn't work with -inf/inf #8674

@Akshat-Tripathi

🐛 Bug

Hi, I was working on some PyTorch code that is designed to run on multiple backends, including TPU via torch_xla. The code calls torch.nan_to_num() with infinite replacement values, which works on PyTorch's CPU and GPU backends but fails on TPU.

To Reproduce

import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
cpu_device = torch.device("cpu")

neg_inf = float("-inf")
pos_inf = float("inf")

xla_tensor = torch.zeros(3, 3, device=xla_device)
cpu_tensor = torch.zeros(3, 3, device=cpu_device)

try:
    # On the XLA device this raises a RuntimeError (see stack trace below).
    torch.nan_to_num(
        xla_tensor,
        nan=neg_inf,
        posinf=pos_inf,
        neginf=neg_inf
    )
except RuntimeError as e:
    print(f"Failed: {e}")

# The same call succeeds on the CPU backend.
torch.nan_to_num(
    cpu_tensor,
    nan=neg_inf,
    posinf=pos_inf,
    neginf=neg_inf
)
print("Passed")

Steps to reproduce the behavior:

  1. Run the above code snippet.

This is the stack trace I get with the XLA backend.

[rank0]: RuntimeError: torch_xla/csrc/aten_xla_type.cpp:2346 : Check failed: min_max.min.toDouble() <= replacement.toDouble() && replacement.toDouble() <= min_max.max.toDouble() 
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace()
[rank0]:        torch_xla::XLANativeFunctions::nan_to_num(at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        at::_ops::nan_to_num::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<double>, std::optional<double>, std::optional<double>)
[rank0]:        torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
[rank0]:        torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
[rank0]:        [... repeated CPython interpreter frames (_PyEval_EvalFrameDefault, _PyFunction_Vectorcall, PyObject_Call, ...) elided ...]
[rank0]:        PyEval_EvalCode
[rank0]:        _PyRun_SimpleFileObject
[rank0]:        _PyRun_AnyFileObject
[rank0]:        Py_RunMain
[rank0]:        Py_BytesMain
[rank0]:        __libc_start_main
[rank0]: *** End stack trace ***
[rank0]: Type BFloat16 replacement value -inf must be in the range [-3.40282e+38, 3.40282e+38].
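
For what it's worth, bfloat16 itself can represent infinities; the check that fires here (aten_xla_type.cpp:2346) compares the replacement value against a finite min/max range (the bounds in the message match float32's finite limits), which -inf can never satisfy. A minimal illustration on plain CPU tensors:

import torch

# bfloat16 has no trouble holding an infinite value...
t = torch.tensor(float("-inf"), dtype=torch.bfloat16)
print(t)  # tensor(-inf, dtype=torch.bfloat16)

# ...but -inf lies outside any finite interval, so a range check like
# the one in the error message necessarily rejects it.
print(float("-inf") >= -3.40282e+38)  # False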

Expected behavior

torch.nan_to_num() should behave the same way on every PyTorch backend.
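
For reference, this is the eager CPU behavior being relied on: infinite replacement values are accepted, and posinf=inf / neginf=-inf effectively leave infinities in place (expected output shown as a comment):

import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0])
out = torch.nan_to_num(x, nan=0.0, posinf=float("inf"), neginf=float("-inf"))
print(out)  # tensor([0., inf, -inf, 1.])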

Environment

  • Reproducible on: XLA backend (TPU)
  • torch_xla version: 2.6.0.dev20241126

Additional context

As a workaround I can use the dtype's finite min/max values in place of the infinities, but the results then saturate at finite values, so it's still not ideal.
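
A minimal sketch of that workaround (nan_to_num_xla_safe is a hypothetical helper, not part of any API), substituting the dtype's finite limits from torch.finfo for the infinities:

import torch

def nan_to_num_xla_safe(t: torch.Tensor) -> torch.Tensor:
    # Replace nan/-inf with the dtype's finite minimum and +inf with its
    # finite maximum, mirroring the original call's nan=neg_inf choice.
    finfo = torch.finfo(t.dtype)
    return torch.nan_to_num(t, nan=finfo.min, posinf=finfo.max, neginf=finfo.min)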

Labels

bug: Something isn't working
pytorch divergence: XLA behavior doesn't match PyTorch eager frontend
