Skip to content

Commit

Permalink
Add nan_to_num helper (#796)
Browse files Browse the repository at this point in the history
As well as numpy-like posinf and neginf
  • Loading branch information
Dhruvanshu-Joshi authored Jul 7, 2024
1 parent 05d376f commit ca10298
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
75 changes: 75 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,22 @@ def largest(*args):
return max(stack(args), axis=0)


def isposinf(x):
"""
Return if the input variable has positive infinity element
"""
return eq(x, np.inf)


def isneginf(x):
"""
Return if the input variable has negative infinity element
"""
return eq(x, -np.inf)


@scalar_elemwise
def lt(a, b):
"""a < b"""
Expand Down Expand Up @@ -2913,6 +2929,62 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
return vectorize_node_fallback(op, node, batched_x, batched_y)


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"""
Replace NaN with zero and infinity with large finite numbers (default
behaviour) or with the numbers defined by the user using the `nan`,
`posinf` and/or `neginf` keywords.
NaN is replaced by zero or by the user defined value in
`nan` keyword, infinity is replaced by the largest finite floating point
values representable by ``x.dtype`` or by the user defined value in
`posinf` keyword and -infinity is replaced by the most negative finite
floating point values representable by ``x.dtype`` or by the user defined
value in `neginf` keyword.
Parameters
----------
x : symbolic tensor
Input array.
nan
The value to replace NaN's with in the tensor (default = 0).
posinf
The value to replace +INF with in the tensor (default max
in range representable by ``x.dtype``).
neginf
The value to replace -INF with in the tensor (default min
in range representable by ``x.dtype``).
Returns
-------
out
The tensor with NaN's, +INF, and -INF replaced with the
specified and/or default substitutions.
"""
# Replace NaN's with nan keyword
is_nan = isnan(x)
is_pos_inf = isposinf(x)
is_neg_inf = isneginf(x)

x = switch(is_nan, nan, x)

# Get max and min values representable by x.dtype
maxf = posinf
minf = neginf

# Specify the value to replace +INF and -INF with
if maxf is None:
maxf = np.finfo(x.real.dtype).max
if minf is None:
minf = np.finfo(x.real.dtype).min

# Replace +INF and -INF values
x = switch(is_pos_inf, maxf, x)
x = switch(is_neg_inf, minf, x)

return x


# NumPy logical aliases
square = sqr

Expand Down Expand Up @@ -2951,6 +3023,8 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
"not_equal",
"isnan",
"isinf",
"isposinf",
"isneginf",
"allclose",
"isclose",
"and_",
Expand Down Expand Up @@ -3069,4 +3143,5 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
"logaddexp",
"logsumexp",
"hyp2f1",
"nan_to_num",
]
46 changes: 46 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@
isinf,
isnan,
isnan_,
isneginf,
isposinf,
log,
log1mexp,
log1p,
Expand All @@ -96,6 +98,7 @@
minimum,
mod,
mul,
nan_to_num,
neg,
neq,
outer,
Expand Down Expand Up @@ -3689,3 +3692,46 @@ def test_grad_n_undefined(self):
n = scalar(dtype="int64")
with pytest.raises(NullTypeGradError):
grad(polygamma(n, 0.5), wrt=n)


def test_infs():
x = tensor(shape=(7,))

f_pos = function([x], isposinf(x))
f_neg = function([x], isneginf(x))

y = np.array([1, np.inf, 2, np.inf, -np.inf, -np.inf, 4]).astype(x.dtype)
out_pos = f_pos(y)
out_neg = f_neg(y)

np.testing.assert_allclose(
out_pos,
[0, 1, 0, 1, 0, 0, 0],
)
np.testing.assert_allclose(
out_neg,
[0, 0, 0, 0, 1, 1, 0],
)


@pytest.mark.parametrize(
["nan", "posinf", "neginf"],
[(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
)
def test_nan_to_num(nan, posinf, neginf):
x = tensor(shape=(7,))

out = nan_to_num(x, nan, posinf, neginf)

f = function([x], out)

y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4]).astype(x.dtype)
out = f(y)

posinf = np.finfo(x.real.dtype).max if posinf is None else posinf
neginf = np.finfo(x.real.dtype).min if neginf is None else neginf

np.testing.assert_allclose(
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)

0 comments on commit ca10298

Please sign in to comment.