diff --git a/tensorflow_addons/activations/mish.py b/tensorflow_addons/activations/mish.py index 2a5e5a8f05..e020cb0276 100644 --- a/tensorflow_addons/activations/mish.py +++ b/tensorflow_addons/activations/mish.py @@ -17,6 +17,7 @@ from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO +from tensorflow_addons import options _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @@ -36,6 +37,17 @@ def mish(x: types.TensorLike) -> tf.Tensor: A `Tensor`. Has the same type as `x`. """ x = tf.convert_to_tensor(x) + + if not options.TF_ADDONS_PY_OPS: + try: + return _mish_custom_op(x) + except tf.errors.NotFoundError: + options.warn_fallback("mish") + + return _mish_custom_op(x) + + +def _mish_custom_op(x): return _activation_so.ops.addons_mish(x) diff --git a/tensorflow_addons/activations/softshrink.py b/tensorflow_addons/activations/softshrink.py index 238cc19036..861c35dcad 100644 --- a/tensorflow_addons/activations/softshrink.py +++ b/tensorflow_addons/activations/softshrink.py @@ -18,6 +18,7 @@ from tensorflow_addons.utils import types from tensorflow_addons.utils.resource_loader import LazySO +from tensorflow_addons import options _activation_so = LazySO("custom_ops/activations/_activation_ops.so") @@ -40,6 +41,17 @@ def softshrink( A `Tensor`. Has the same type as `x`. """ x = tf.convert_to_tensor(x) + + if not options.TF_ADDONS_PY_OPS: + try: + return _softshrink_custom_op(x, lower, upper) + except tf.errors.NotFoundError: + options.warn_fallback("softshrink") + + return _softshrink_py(x, lower, upper) + + +def _softshrink_custom_op(x, lower, upper): return _activation_so.ops.addons_softshrink(x, lower, upper) diff --git a/tensorflow_addons/activations/softshrink_test.py b/tensorflow_addons/activations/softshrink_test.py index 522bc0dc11..f18353804c 100644 --- a/tensorflow_addons/activations/softshrink_test.py +++ b/tensorflow_addons/activations/softshrink_test.py @@ -18,7 +18,10 @@ import numpy as np import tensorflow as tf from tensorflow_addons.activations import softshrink -from tensorflow_addons.activations.softshrink import _softshrink_py +from tensorflow_addons.activations.softshrink import ( + _softshrink_py, + _softshrink_custom_op, +) from tensorflow_addons.utils import test_utils @@ -26,7 +29,7 @@ class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase): def test_invalid(self): with self.assertRaisesOpError("lower must be less than or equal to upper."): - y = softshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) + y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) self.evaluate(y) @parameterized.named_parameters(