diff --git a/ivy/array/experimental/activations.py b/ivy/array/experimental/activations.py
index 527fb0af9a1ab..61d8882062515 100644
--- a/ivy/array/experimental/activations.py
+++ b/ivy/array/experimental/activations.py
@@ -109,3 +109,50 @@ def prelu(
         -------
         """
         return ivy.prelu(self._data, slope, out=out)
+
+    def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
+        """Applies the rectified linear unit 6 function element-wise.
+
+        Parameters
+        ----------
+        self
+            input array
+        out
+            optional output array, for writing the result to.
+            It must have a shape that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            an array containing the rectified linear unit 6 activation
+            of each element in ``self``.
+
+        Examples
+        --------
+        With :class:`ivy.Array` input:
+
+        >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
+        >>> y = x.relu6()
+        >>> print(y)
+        ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
+
+        >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
+        >>> y = ivy.zeros(9)
+        >>> x.relu6(out=y)
+        >>> print(y)
+        ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
+
+        With :class:`ivy.Container` input:
+
+        >>> x = ivy.Container(
+        ...     a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
+        ...     b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
+        ... )
+        >>> x = ivy.relu6(x, out=x)
+        >>> print(x)
+        {
+            a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
+            b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
+        }
+        """
+        return ivy.relu6(self._data, out=out)
diff --git a/ivy/container/experimental/activations.py b/ivy/container/experimental/activations.py
index d6f7ce2230431..c23a9e526430f 100644
--- a/ivy/container/experimental/activations.py
+++ b/ivy/container/experimental/activations.py
@@ -323,3 +323,131 @@ def prelu(
             map_sequences=map_sequences,
             out=out,
         )
+
+    @staticmethod
+    def static_relu6(
+        x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
+        /,
+        *,
+        key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+        to_apply: bool = True,
+        prune_unapplied: bool = False,
+        map_sequences: bool = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container static method variant of ivy.relu6.
+        This method simply wraps the function, and so the docstring
+        for ivy.relu6 also applies to this method with minimal changes.
+
+        Parameters
+        ----------
+        x
+            input container.
+        key_chains
+            The key-chains to apply or not apply the method to. Default is ``None``.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is ``True``.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is ``False``.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+            Default is ``False``.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            a container with the rectified linear unit 6 activation function
+            applied element-wise.
+
+        Examples
+        --------
+        >>> x = ivy.Container(
+        ...     a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
+        ...     b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
+        ... )
+        >>> y = ivy.Container.static_relu6(x)
+        >>> print(y)
+        {
+            a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
+            b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
+        }
+
+        """
+        return ContainerBase.cont_multi_map_in_function(
+            "relu6",
+            x,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
+
+    def relu6(
+        self: ivy.Container,
+        /,
+        *,
+        key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+        to_apply: bool = True,
+        prune_unapplied: bool = False,
+        map_sequences: bool = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container instance method variant of ivy.relu6.
+        This method simply wraps the function, and so the docstring
+        for ivy.relu6 also applies to this method with minimal changes.
+
+        Parameters
+        ----------
+        self
+            input container.
+        key_chains
+            The key-chains to apply or not apply the method to. Default is ``None``.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is ``True``.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is ``False``.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+            Default is ``False``.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            a container with the rectified linear unit 6 activation function
+            applied element-wise.
+
+        Examples
+        --------
+        >>> x = ivy.Container(
+        ...     a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
+        ...     b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
+        ... )
+        >>> y = x.relu6()
+        >>> print(y)
+        {
+            a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
+            b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
+        }
+
+        """
+        return self.static_relu6(
+            self,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py
index b181680031741..5fce86fd64ff2 100644
--- a/ivy/functional/backends/jax/experimental/activations.py
+++ b/ivy/functional/backends/jax/experimental/activations.py
@@ -1,8 +1,11 @@
 from typing import Optional, Union
 
 # global
+import jax
 import jax.numpy as jnp
 from ivy.functional.backends.jax import JaxArray
+from jax import lax
+import ivy
 
 
 def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
@@ -13,6 +16,20 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
     return jnp.log(x / (1 - x))
 
 
+def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+    relu6_func = jax.nn.relu6
+
+    # sets gradient at 0 and 6 to 0 instead of 0.5
+    # can refactor to jax.nn.relu6 when this PR is merged
+    # https://github.com/google/jax/pull/14682
+    def custom_grad_func(x_and_grad, one):
+        return lax.select(
+            (6 > x_and_grad[0]) & (x_and_grad[0] > 0), one, lax.full_like(one, 0))
+
+    new_func = ivy.bind_custom_gradient_function(relu6_func, custom_grad_func)
+
+    return new_func(x).astype(x.dtype)
+
+
 def thresholded_relu(
     x: JaxArray,
     /,
diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py
index 33f431c6e3a39..c5b9f51954a69 100644
--- a/ivy/functional/backends/numpy/experimental/activations.py
+++ b/ivy/functional/backends/numpy/experimental/activations.py
@@ -31,3 +31,11 @@ def thresholded_relu(
 
 
 thresholded_relu.support_native_out = True
+
+
+@_scalar_output_to_0d_array
+def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
+    return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)
+
+
+relu6.support_native_out = True
diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py
index ea35b0d2a959f..98fef2a5d3810 100644
--- a/ivy/functional/backends/tensorflow/experimental/activations.py
+++ b/ivy/functional/backends/tensorflow/experimental/activations.py
@@ -30,3 +30,8 @@ def thresholded_relu(
     out: Optional[Tensor] = None,
 ) -> Tensor:
     return tf.where(x > threshold, x, 0)
+
+
+@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version)
+def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
+    return tf.nn.relu6(x)
diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py
index ef3ff49ae66ca..f43301f45cee9 100644
--- a/ivy/functional/backends/torch/experimental/activations.py
+++ b/ivy/functional/backends/torch/experimental/activations.py
@@ -23,3 +23,10 @@ def thresholded_relu(
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return torch.threshold(x, threshold=threshold, value=0)
+
+
+@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, backend_version)
+def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.nn.functional.relu6(x)
diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py
index 376678264742c..debf121a89059 100644
--- a/ivy/functional/frontends/jax/nn/non_linear_activations.py
+++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -271,7 +271,7 @@ def relu(x):
 
 @to_ivy_arrays_and_back
 def relu6(x):
-    res = ivy.minimum(ivy.maximum(x, 0.0), 6.0)
+    res = ivy.relu6(x)
     return _type_conversion_64(res)
 
 
diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py
index 3d01707dce2a3..5a0f55aadbcfa 100644
--- a/ivy/functional/frontends/tensorflow/nn.py
+++ b/ivy/functional/frontends/tensorflow/nn.py
@@ -435,6 +435,11 @@ def relu(features, name=None):
     return ivy.relu(features)
 
 
+@to_ivy_arrays_and_back
+def relu6(features, name=None):
+    return ivy.relu6(features)
+
+
 @to_ivy_arrays_and_back
 def softmax(logits, axis=None, name=None):
     return ivy.softmax(logits, axis=axis)
diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py
index 3f2152c3ea750..2614ff173304a 100644
--- a/ivy/functional/frontends/tensorflow/raw_ops.py
+++ b/ivy/functional/frontends/tensorflow/raw_ops.py
@@ -485,7 +485,12 @@ def Pow(*, x, y, name="Pow"):
     return ivy.pow(x, y)
 
 
-Relu6 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.nn.relu6))
+Relu6 = to_ivy_arrays_and_back(
+    map_raw_ops_alias(
+        tf_frontend.nn.relu6,
+        kwargs_to_update={"x": "features"},
+    )
+)
 
 
 Sigmoid = to_ivy_arrays_and_back(
diff --git a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py
index 014bc88c09a9b..9aea017339fa4 100644
--- a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py
@@ -172,7 +172,7 @@ def threshold_(input, threshold, value):
 
 
 def relu6(input, inplace=False):
-    ret = ivy.minimum(ivy.maximum(input, 0), 6)
+    ret = ivy.relu6(input)
     if inplace:
         ivy.inplace_update(input, ret)
         return input
@@ -307,7 +307,7 @@ def leaky_relu_(input, negative_slope=0.01):
 
 
 def hardswish(input, inplace=False):
-    relu6_val = ivy.minimum(ivy.maximum(ivy.add(input, 3), 0), 6)
+    relu6_val = ivy.relu6(ivy.add(input, 3))
     ret = ivy.multiply(input, ivy.divide(relu6_val, 6))
     if inplace:
         ivy.inplace_update(input, ret)
diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py
index a349c1ef3ff72..e00a50db42bab 100644
--- a/ivy/functional/ivy/experimental/activations.py
+++ b/ivy/functional/ivy/experimental/activations.py
@@ -6,6 +6,7 @@
 from ivy.utils.backend import current_backend
 from ivy.utils.exceptions import handle_exceptions
 from ivy.func_wrapper import (
+    handle_array_function,
     handle_nestable,
     to_native_arrays_and_back,
     handle_array_like_without_promotion,
@@ -173,3 +174,59 @@ def thresholded_relu(
     }
     """
     return current_backend(x).thresholded_relu(x, threshold=threshold, out=out)
+
+
+@to_native_arrays_and_back
+@handle_out_argument
+@handle_nestable
+@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
+def relu6(
+    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
+) -> ivy.Array:
+    """Applies the rectified linear unit 6 function element-wise.
+
+    Parameters
+    ----------
+    x
+        input array
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
+
+    Returns
+    -------
+    ret
+        an array containing the rectified linear unit 6 activation of each element in
+        ``x``.
+
+    Examples
+    --------
+    With :class:`ivy.Array` input:
+
+    >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
+    >>> y = ivy.relu6(x)
+    >>> print(y)
+    ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
+
+    >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
+    >>> y = ivy.zeros(9)
+    >>> ivy.relu6(x, out=y)
+    >>> print(y)
+    ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
+
+    With :class:`ivy.Container` input:
+
+    >>> x = ivy.Container(
+    ...     a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
+    ...     b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
+    ... )
+    >>> x = ivy.relu6(x, out=x)
+    >>> print(x)
+    {
+        a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
+        b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
+    }
+    """
+    return current_backend(x).relu6(x, out=out)
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
index f79b826b5fbc2..dfb3f94b807ad 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
@@ -1105,6 +1105,36 @@ def test_tensorflow_relu(
     )
 
 
+# relu6
+@handle_frontend_test(
+    fn_tree="tensorflow.nn.relu6",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("numeric"),
+        num_arrays=1,
+        min_value=-20,
+        max_value=20,
+    ),
+    test_with_out=st.just(False),
+)
+def test_tensorflow_relu6(
+    *,
+    dtype_and_x,
+    test_flags,
+    frontend,
+    fn_tree,
+    on_device,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        test_flags=test_flags,
+        frontend=frontend,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        features=x[0],
+    )
+
+
 # softmax
 @handle_frontend_test(
     fn_tree="tensorflow.nn.softmax",
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
index f49c7f49d40c4..c787a80de8643 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
@@ -106,3 +106,34 @@ def test_prelu(
         x=x[0],
         slope=slope,
     )
+
+
+# relu6
+@handle_test(
+    fn_tree="functional.ivy.experimental.relu6",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("numeric"),
+        large_abs_safety_factor=2,
+        small_abs_safety_factor=2,
+        safety_factor_scale="log",
+    ),
+)
+def test_relu6(
+    *,
+    dtype_and_x,
+    test_flags,
+    backend_fw,
+    fn_name,
+    on_device,
+    ground_truth_backend,
+):
+    dtype, x = dtype_and_x
+    helpers.test_function(
+        ground_truth_backend=ground_truth_backend,
+        input_dtypes=dtype,
+        fw=backend_fw,
+        test_flags=test_flags,
+        fn_name=fn_name,
+        on_device=on_device,
+        x=x[0],
+    )
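For quick reference, below is a minimal usage sketch (not part of the diff) of the relu6 API these changes add, exercised through the functional, Array, and Container interfaces. It assumes ivy is installed with the NumPy backend available; any other supported backend should behave the same way.

```python
import ivy

# pick a backend for eager execution (NumPy assumed here)
ivy.set_backend("numpy")

x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])

# functional API added in ivy/functional/ivy/experimental/activations.py
print(ivy.relu6(x))  # ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

# Array instance method added in ivy/array/experimental/activations.py
print(x.relu6())     # same result as above

# Container method added in ivy/container/experimental/activations.py,
# applied leaf-wise to every array in the container
c = ivy.Container(a=x, b=ivy.array([7., 8., 9.]))
print(c.relu6())     # every leaf is clipped to the range [0, 6]
```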