
Reimplement ReLu6 in the backends and frontends (ivy-llc#10587)
Co-authored-by: CatB1t <skytedits@gmail.com>
MahmoudAshraf97 and CatB1t authored Feb 25, 2023
1 parent 30c9e49 commit 757a0e0
Showing 13 changed files with 344 additions and 4 deletions.
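For reference, ReLU6 is the standard rectified linear unit clamped at 6, i.e. min(max(x, 0), 6). A minimal NumPy sketch of the behaviour the reimplemented backends target (illustrative only, not part of this diff):

import numpy as np

def relu6_reference(x: np.ndarray) -> np.ndarray:
    # Clamp element-wise to the closed interval [0, 6].
    return np.minimum(np.maximum(x, 0.0), 6.0)

print(relu6_reference(np.array([-1., 0., 1., 5., 6., 7.])))
# [0. 0. 1. 5. 6. 6.]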
47 changes: 47 additions & 0 deletions ivy/array/experimental/activations.py
@@ -109,3 +109,50 @@ def prelu(
-------
"""
return ivy.prelu(self._data, slope, out=out)

def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""Applies the rectified linear unit 6 function element-wise.
Parameters
----------
self
input array
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
Returns
-------
ret
an array containing the rectified linear unit 6 activation
of each element in ``self``.
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = x.relu6()
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.zeros(9)
>>> x.relu6(out=y)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
With :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
...                   b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]))
>>> x = ivy.relu6(x, out=x)
>>> print(x)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return ivy.relu6(self._data, out=out)
128 changes: 128 additions & 0 deletions ivy/container/experimental/activations.py
@@ -323,3 +323,131 @@ def prelu(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def static_relu6(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.relu6.
This method simply wraps the function, and so the docstring
for ivy.relu6 also applies to this method with minimal changes.
Parameters
----------
x
input container.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Returns
-------
ret
a container with the rectified linear unit 6 activation function
applied element-wise.
Examples
--------
>>> x = ivy.Container(a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
...                   b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]))
>>> y = ivy.Container.static_relu6(x)
>>> print(y)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return ContainerBase.cont_multi_map_in_function(
"relu6",
x,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def relu6(
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.relu6.
This method simply wraps the function, and so the docstring
for ivy.relu6 also applies to this method with minimal changes.
Parameters
----------
self
input container.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Returns
-------
ret
a container with the rectified linear unit 6 activation function
applied element-wise.
Examples
--------
>>> x = ivy.Container(a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
...                   b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]))
>>> y = x.relu6()
>>> print(y)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return self.static_relu6(
self,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
17 changes: 17 additions & 0 deletions ivy/functional/backends/jax/experimental/activations.py
@@ -1,8 +1,11 @@
from typing import Optional, Union

# global
import jax
import jax.numpy as jnp
from ivy.functional.backends.jax import JaxArray
from jax import lax
import ivy


def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
@@ -13,6 +16,20 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
return jnp.log(x / (1 - x))


def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
relu6_func = jax.nn.relu6

# sets gradient at 0 and 6 to 0 instead of 0.5
# can refactor to jax.nn.relu6 when this PR is merged
# https://github.com/google/jax/pull/14682
def custom_grad_func(x_and_grad, one):
    return lax.select(
        (6 > x_and_grad[0]) & (x_and_grad[0] > 0), one, lax.full_like(one, 0)
    )

new_func = ivy.bind_custom_gradient_function(relu6_func, custom_grad_func)

return new_func(x).astype(x.dtype)
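The JAX backend above wraps jax.nn.relu6 with a custom gradient so that the derivative at the boundary points 0 and 6 is taken to be 0. The same convention can be sketched standalone with jax.custom_vjp instead of ivy.bind_custom_gradient_function (illustrative only; relu6_sketch and its helpers are not from this commit):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def relu6_sketch(x):
    return jnp.clip(x, 0.0, 6.0)

def _relu6_fwd(x):
    # Keep the primal input as the residual for the backward pass.
    return relu6_sketch(x), x

def _relu6_bwd(x, g):
    # Gradient is 1 strictly inside (0, 6) and 0 elsewhere, including at x == 0 and x == 6.
    return (jnp.where((x > 0) & (x < 6), g, 0.0),)

relu6_sketch.defvjp(_relu6_fwd, _relu6_bwd)

print(jax.grad(relu6_sketch)(0.0), jax.grad(relu6_sketch)(3.0), jax.grad(relu6_sketch)(6.0))
# 0.0 1.0 0.0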


def thresholded_relu(
x: JaxArray,
/,
8 changes: 8 additions & 0 deletions ivy/functional/backends/numpy/experimental/activations.py
@@ -31,3 +31,11 @@ def thresholded_relu(


thresholded_relu.support_native_out = True


@_scalar_output_to_0d_array
def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)


relu6.support_native_out = True
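A quick standalone check of the NumPy clamp above, including the native out= path that support_native_out advertises (illustrative only, independent of ivy):

import numpy as np

x = np.array([-1., 0., 3., 6., 7.])
out = np.empty_like(x)
np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)
print(out)
# [0. 0. 3. 6. 6.]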
5 changes: 5 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/activations.py
@@ -30,3 +30,8 @@ def thresholded_relu(
out: Optional[Tensor] = None,
) -> Tensor:
return tf.where(x > threshold, x, 0)


@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version)
def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
return tf.nn.relu6(x)
7 changes: 7 additions & 0 deletions ivy/functional/backends/torch/experimental/activations.py
@@ -23,3 +23,10 @@ def thresholded_relu(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.threshold(x, threshold=threshold, value=0)


@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, backend_version)
def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.nn.functional.relu6(x)

relu6.unsupported_dtypes = ("float16", "bfloat16",)
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -271,7 +271,7 @@ def relu(x):

@to_ivy_arrays_and_back
def relu6(x):
res = ivy.minimum(ivy.maximum(x, 0.0), 6.0)
res = ivy.relu6(x)
return _type_conversion_64(res)


5 changes: 5 additions & 0 deletions ivy/functional/frontends/tensorflow/nn.py
@@ -435,6 +435,11 @@ def relu(features, name=None):
return ivy.relu(features)


@to_ivy_arrays_and_back
def relu6(features, name=None):
return ivy.relu6(features)


@to_ivy_arrays_and_back
def softmax(logits, axis=None, name=None):
return ivy.softmax(logits, axis=axis)
7 changes: 6 additions & 1 deletion ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
@@ -485,7 +485,12 @@ def Pow(*, x, y, name="Pow"):
return ivy.pow(x, y)


Relu6 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.nn.relu6))
Relu6 = to_ivy_arrays_and_back(
map_raw_ops_alias(
tf_frontend.nn.relu6,
kwargs_to_update={"x": "features"},
)
)


Sigmoid = to_ivy_arrays_and_back(
4 changes: 2 additions & 2 deletions ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py
@@ -172,7 +172,7 @@ def threshold_(input, threshold, value):


def relu6(input, inplace=False):
ret = ivy.minimum(ivy.maximum(input, 0), 6)
ret = ivy.relu6(input)
if inplace:
ivy.inplace_update(input, ret)
return input
@@ -307,7 +307,7 @@ def leaky_relu_(input, negative_slope=0.01):


def hardswish(input, inplace=False):
relu6_val = ivy.minimum(ivy.maximum(ivy.add(input, 3), 0), 6)
relu6_val = ivy.relu6(ivy.add(input, 3))
ret = ivy.multiply(input, ivy.divide(relu6_val, 6))
if inplace:
ivy.inplace_update(input, ret)
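The hardswish change above relies on the identity hardswish(x) = x * relu6(x + 3) / 6, which is exactly what the frontend code computes. A small standalone sketch of that equivalence (illustrative only, not part of this diff):

import numpy as np

def relu6(x):
    return np.minimum(np.maximum(x, 0.0), 6.0)

def hardswish(x):
    return x * relu6(x + 3.0) / 6.0

print(hardswish(np.array([-4., -1., 0., 1., 4.])))
# approximately [-0. -0.333 0. 0.667 4.]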
57 changes: 57 additions & 0 deletions ivy/functional/ivy/experimental/activations.py
@@ -6,6 +6,7 @@
from ivy.utils.backend import current_backend
from ivy.utils.exceptions import handle_exceptions
from ivy.func_wrapper import (
handle_array_function,
handle_nestable,
to_native_arrays_and_back,
handle_array_like_without_promotion,
@@ -173,3 +174,59 @@ def thresholded_relu(
}
"""
return current_backend(x).thresholded_relu(x, threshold=threshold, out=out)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
@handle_array_function
def relu6(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
"""Applies the rectified linear unit 6 function element-wise.
Parameters
----------
x
input array
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Returns
-------
ret
an array containing the rectified linear unit 6 activation of each element in
``x``.
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.relu6(x)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.zeros(9)
>>> ivy.relu6(x, out = y)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
With :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
...                   b=ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]))
>>> x = ivy.relu6(x, out=x)
>>> print(x)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return current_backend(x).relu6(x, out=out)
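A hypothetical end-to-end call of the new function once a backend is selected (sketch only; output shown for the NumPy backend):

import ivy

ivy.set_backend("numpy")
x = ivy.array([-1., 2., 7.])
print(ivy.relu6(x))
# ivy.array([0., 2., 6.])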

