diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py
index cc296b1..ab5388a 100644
--- a/tf2jax/_src/ops.py
+++ b/tf2jax/_src/ops.py
@@ -1390,6 +1390,46 @@ def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
   return _func
 
 
+def _mul_no_nan_forward(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  """Returns zero if y is zero, even if x is infinite or NaN."""
+  mul = anp.multiply(x, y)
+  return jnp.where(y == 0.0, jnp.zeros_like(mul), mul)
+
+
+@register_operation("MulNoNan")
+def _mul_no_nan(proto):
+  """Parse a MulNoNan op."""
+  _check_attrs(proto, {"T"})
+
+  @jax.custom_gradient
+  def _func(
+      x: jnp.ndarray, y: jnp.ndarray
+  ) -> Tuple[
+      jnp.ndarray,
+      Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]],
+  ]:
+    def _grad(g: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+      """Given upstream grad G and a Mul op Z = X * Y, the gradients are:
+
+        dX = G * Y
+        dY = X * G
+
+      Args:
+        g: Upstream gradient.
+
+      Returns:
+        dX and dY, each computed with the same zero-masking rule as the
+        forward pass (matching TF's MulNoNan gradient).
+      """
+      dx = _mul_no_nan_forward(g, y)
+      dy = _mul_no_nan_forward(x, g)
+      return dx, dy
+
+    return _mul_no_nan_forward(x, y), _grad
+
+  return _func
+
+
 @register_operation("OneHot")
 def _one_hot(proto):
   """Parse a OneHot Op."""
diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py
index aa4e467..ee7fa32 100644
--- a/tf2jax/_src/ops_test.py
+++ b/tf2jax/_src/ops_test.py
@@ -2276,6 +2276,37 @@ def div_no_nan(inputs):
     with self.subTest("check_backward_pass"):
       self.assertAllClose(jax_gradient, np.asarray(tf_gradient))
 
+  @chex.variants(with_jit=True, without_jit=True)
+  @parameterized.parameters(
+      (2., 2.), (0., 0.), (np.nan, 0.), (np.inf, 0.), (0., np.nan), (0., np.inf)
+  )
+  def test_mul_no_nan(self, x, y):
+
+    @tf.function
+    def mul_no_nan(inputs):
+      x, y = inputs
+      return tf.raw_ops.MulNoNan(x=x, y=y)
+
+    x = np.array(x)
+    y = np.array(y)
+    tf_x = tf.convert_to_tensor(x)
+    tf_y = tf.convert_to_tensor(y)
+
+    with tf.GradientTape() as g:
+      g.watch(tf_x)
+      g.watch(tf_y)
+      output = mul_no_nan([tf_x, tf_y])
+    tf_gradient = g.gradient(output, [tf_x, tf_y])
+
+    jax_func = tf2jax.convert_functional(mul_no_nan, [x, y])
+    jax_func = self.variant(jax_func)
+    jax_gradient = jax.grad(jax_func)([x, y])
+
+    with self.subTest("check_forward_pass"):
+      self.assertAllClose(jax_func([x, y]), np.asarray(output))
+    with self.subTest("check_backward_pass"):
+      self.assertAllClose(jax_gradient, np.asarray(tf_gradient))
+
   @chex.variants(with_jit=True, without_jit=True)
   def test_angle(self):
     inputs = np.array([-2.25 + 4.75j, 3.25 + 5.75j], dtype=np.csingle)
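For reviewers who want to poke at the gradient logic outside the converter, here is a minimal self-contained JAX sketch of the same masking trick. The names _masked_mul and mul_no_nan are illustrative only, not tf2jax or TF API; it simply mirrors the dX = G * Y, dY = X * G rule from the docstring above, with the zero-mask applied in both the forward and backward passes.

import jax
import jax.numpy as jnp
import numpy as np


def _masked_mul(x, y):
  # Zero wherever y == 0, even if x is inf or NaN at those positions.
  mul = jnp.multiply(x, y)
  return jnp.where(y == 0.0, jnp.zeros_like(mul), mul)


@jax.custom_gradient
def mul_no_nan(x, y):
  def grad(g):
    # dX = G * Y and dY = X * G, each with the same zero-masking so a zero
    # operand suppresses inf/NaN contributions from the other side.
    return _masked_mul(g, y), _masked_mul(x, g)
  return _masked_mul(x, y), grad


x = jnp.asarray(np.inf, dtype=jnp.float32)
y = jnp.asarray(0.0, dtype=jnp.float32)
print(mul_no_nan(x, y))                            # 0.0 rather than nan
print(jax.grad(mul_no_nan, argnums=(0, 1))(x, y))  # dX is 0.0; dY stays inf

The (inf, 0) and (nan, 0) pairs printed here are the same corner cases the new parameterized test exercises against tf.raw_ops.MulNoNan.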