Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement ReLu6 in the backends and frontends #10587

Merged
merged 13 commits into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ivy/array/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,50 @@ def prelu(
-------
"""
return ivy.prelu(self._data, slope, out=out)

def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""Applies the rectified linear unit 6 function element-wise.

Parameters
----------
x
input array
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.

Returns
-------
ret
an array containing the rectified linear unit 6 activation
of each element in ``x``.

Examples
--------
With :class:`ivy.Array` input:

>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.relu6(x)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.zeros(9)
>>> ivy.relu6(x, out = y)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

With :class:`ivy.Container` input:

>>> x = {
a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
}
>>> x = ivy.relu6(x, out=x)
>>> print(x)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return ivy.relu6(self._data, out=out)
128 changes: 128 additions & 0 deletions ivy/container/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,131 @@ def prelu(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def static_relu6(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.relu6.
This method simply wraps the function, and so the docstring
for ivy.relu6 also applies to this method with minimal changes.

Parameters
----------
x
input container.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
a container with the rectified linear 6 activation unit function
applied element-wise.

Examples
--------
>>> x = {
a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
}
>>> y = ivy.Container.static_relu6(x)
>>> print(y)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}

"""
return ContainerBase.cont_multi_map_in_function(
"relu6",
x,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def relu6(
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.relu6.
This method simply wraps the function, and so the docstring
for ivy.relu6 also applies to this method with minimal changes.

Parameters
----------
self
input container.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
a container with the rectified linear 6 activation unit function
applied element-wise.

Examples
--------
>>> x = {
a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
}
>>> y = x.relu()
>>> print(y)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}

"""
return self.static_relu6(
self,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
17 changes: 17 additions & 0 deletions ivy/functional/backends/jax/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Optional, Union

# global
import jax
import jax.numpy as jnp
from ivy.functional.backends.jax import JaxArray
from jax import lax
import ivy


def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
Expand All @@ -13,6 +16,20 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None):
return jnp.log(x / (1 - x))


def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
relu6_func = jax.nn.relu6

# sets gradient at 0 and 6 to 0 instead of 0.5
# can refactor to jax.nn.relu6 when this PR is merged
# https://github.com/google/jax/pull/14682
def custom_grad_func(x_and_grad, one): return lax.select(
(6 > x_and_grad[0]) & (x_and_grad[0] > 0), one, lax.full_like(one, 0))

new_func = ivy.bind_custom_gradient_function(relu6_func, custom_grad_func)

return new_func(x).astype(x.dtype)


def thresholded_relu(
x: JaxArray,
/,
Expand Down
8 changes: 8 additions & 0 deletions ivy/functional/backends/numpy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ def thresholded_relu(


thresholded_relu.support_native_out = True


@_scalar_output_to_0d_array
def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)


relu6.support_native_out = True
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ def thresholded_relu(
out: Optional[Tensor] = None,
) -> Tensor:
return tf.where(x > threshold, x, 0)


@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version)
def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
return tf.nn.relu6(x)
7 changes: 7 additions & 0 deletions ivy/functional/backends/torch/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ def thresholded_relu(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.threshold(x, threshold=threshold, value=0)


@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, backend_version)
def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.nn.functional.relu6(x)

relu6.unsupported_dtypes = ("float16", "bfloat16",)
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def relu(x):

@to_ivy_arrays_and_back
def relu6(x):
res = ivy.minimum(ivy.maximum(x, 0.0), 6.0)
res = ivy.relu6(x)
return _type_conversion_64(res)


Expand Down
5 changes: 5 additions & 0 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ def relu(features, name=None):
return ivy.relu(features)


@to_ivy_arrays_and_back
def relu6(features, name=None):
return ivy.relu6(features)


@to_ivy_arrays_and_back
def softmax(logits, axis=None, name=None):
return ivy.softmax(logits, axis=axis)
Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,12 @@ def Pow(*, x, y, name="Pow"):
return ivy.pow(x, y)


Relu6 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.nn.relu6))
Relu6 = to_ivy_arrays_and_back(
map_raw_ops_alias(
tf_frontend.nn.relu6,
kwargs_to_update={"x": "features"},
)
)


Sigmoid = to_ivy_arrays_and_back(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def threshold_(input, threshold, value):


def relu6(input, inplace=False):
ret = ivy.minimum(ivy.maximum(input, 0), 6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, this one also fails for JAX backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same problem as other two tests

ret = ivy.relu6(input)
if inplace:
ivy.inplace_update(input, ret)
return input
Expand Down Expand Up @@ -307,7 +307,7 @@ def leaky_relu_(input, negative_slope=0.01):


def hardswish(input, inplace=False):
relu6_val = ivy.minimum(ivy.maximum(ivy.add(input, 3), 0), 6)
relu6_val = ivy.relu6(ivy.add(input, 3))
ret = ivy.multiply(input, ivy.divide(relu6_val, 6))
if inplace:
ivy.inplace_update(input, ret)
Expand Down
57 changes: 57 additions & 0 deletions ivy/functional/ivy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ivy.utils.backend import current_backend
from ivy.utils.exceptions import handle_exceptions
from ivy.func_wrapper import (
handle_array_function,
handle_nestable,
to_native_arrays_and_back,
handle_array_like_without_promotion,
Expand Down Expand Up @@ -173,3 +174,59 @@ def thresholded_relu(
}
"""
return current_backend(x).thresholded_relu(x, threshold=threshold, out=out)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
@handle_array_function
def relu6(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
"""Applies the rectified linear unit 6 function element-wise.

Parameters
----------
x
input array
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.

Returns
-------
ret
an array containing the rectified linear unit 6 activation of each element in
``x``.

Examples
--------
With :class:`ivy.Array` input:

>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.relu6(x)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

>>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> y = ivy.zeros(9)
>>> ivy.relu6(x, out = y)
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

With :class:`ivy.Container` input:

>>> x = {
a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
}
>>> x = ivy.relu6(x, out=x)
>>> print(x)
{
a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]),
b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.])
}
"""
return current_backend(x).relu6(x, out=out)
Loading