Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the MulNoNan op. #210

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,46 @@ def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
return _func


def _mul_no_nan_forward(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Returns zero if y is zero, even if x if infinite or NaN."""
mul = anp.multiply(x, y)
return jnp.where(y == 0.0, jnp.zeros_like(mul), mul)


@register_operation("MulNoNan")
def _mul_no_nan(proto):
"""Parse a MulNoNan op."""
_check_attrs(proto, {"T"})

@jax.custom_gradient
def _func(
x: jnp.ndarray, y: jnp.ndarray
) -> Tuple[
jnp.ndarray,
Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]],
]:
def _grad(g: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Given upstream grad G and a Div op: Z = X*Y, the gradients are.

dX = G * Y
dY = X * G

Args:
g: Upstream gradient.

Returns:
forward value: mul_no_nan(x, y)
grad: Gradient information in TF format.
"""
dx = _mul_no_nan_forward(g, y)
dy = _mul_no_nan_forward(x, g)
return dx, dy

return _mul_no_nan_forward(x, y), _grad

return _func


@register_operation("OneHot")
def _one_hot(proto):
"""Parse a OneHot Op."""
Expand Down
31 changes: 31 additions & 0 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,37 @@ def div_no_nan(inputs):
with self.subTest("check_backward_pass"):
self.assertAllClose(jax_gradient, np.asarray(tf_gradient))

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
(2., 2.), (0., 0.), (np.nan, 0.), (np.inf, 0.), (0., np.nan), (0., np.inf)
)
def test_mul_no_nan(self, x, y):

@tf.function
def mul_no_nan(inputs):
x, y = inputs
return tf.raw_ops.MulNoNan(x=x, y=y)

x = np.array(x)
y = np.array(y)
tf_x = tf.convert_to_tensor(x)
tf_y = tf.convert_to_tensor(y)

with tf.GradientTape() as g:
g.watch(tf_x)
g.watch(tf_y)
output = mul_no_nan([tf_x, tf_y])
tf_gradient = g.gradient(output, [tf_x, tf_y])

jax_func = tf2jax.convert_functional(mul_no_nan, [x, y])
jax_func = self.variant(jax_func)
jax_gradient = jax.grad(jax_func)([x, y])

with self.subTest("check_forward_pass"):
self.assertAllClose(jax_func([x, y]), np.asarray(output))
with self.subTest("check_backward_pass"):
self.assertAllClose(jax_gradient, np.asarray(tf_gradient))

@chex.variants(with_jit=True, without_jit=True)
def test_angle(self):
inputs = np.array([-2.25 + 4.75j, 3.25 + 5.75j], dtype=np.csingle)
Expand Down
Loading