Skip to content

add softshrink kernel #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ py_library(
"gelu.py",
"hardshrink.py",
"lisht.py",
"softshrink.py",
"sparsemax.py",
"tanhshrink.py",
],
Expand Down Expand Up @@ -84,6 +85,19 @@ py_test(
],
)

py_test(
name = "softshrink_test",
size = "small",
srcs = [
"softshrink_test.py",
],
main = "softshrink_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)

py_test(
name = "tanhshrink_test",
size = "small",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
| gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
| hardshrink| @WindQAQ | windqaq@gmail.com |
| lisht | @WindQAQ | windqaq@gmail.com |
| softshrink| @WindQAQ | windqaq@gmail.com |
| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
| tanhshrink| @fsx950223 | fsx950223@gmail.com |

Expand All @@ -15,6 +16,7 @@
| gelu | gelu | https://arxiv.org/abs/1606.08415 |
| hardshrink| hardshrink | |
| lisht | lisht | https://arxiv.org/abs/1901.05894 |
| softshrink| softshrink | |
| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink| tanhshrink | |

Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
from tensorflow_addons.activations.gelu import gelu
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.lisht import lisht
from tensorflow_addons.activations.softshrink import softshrink
from tensorflow_addons.activations.sparsemax import sparsemax
from tensorflow_addons.activations.tanhshrink import tanhshrink
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class ActivationsTest(tf.test.TestCase):

ALL_ACTIVATIONS = [
"gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
]

def test_serialization(self):
Expand Down
52 changes: 52 additions & 0 deletions tensorflow_addons/activations/softshrink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def softshrink(x, lower=-1.0, upper=1.0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any intuition on defaulting to 1.0? The link you posted (which references pytorch) has a default of 0.5

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no... It's my mistake on both hardshrink and softshrink though I could not find any research paper about the value 0.5. Will change them to 0.5 later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Also couldn't find any rationale for 0.5 other than framework defaults.

"""Soft shrink function.

Computes soft shrink function:
`x - lower if x < lower, x - upper if x > upper else 0`.

Args:
x: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
lower: `float`, lower bound for setting values to zeros.
upper: `float`, upper bound for setting values to zeros.
Returns:
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_softshrink(x, lower, upper)


@tf.RegisterGradient("Addons>Softshrink")
def _softshrink_grad(op, grad):
return _activation_ops_so.addons_softshrink_grad(grad, op.inputs[0],
op.get_attr("lower"),
op.get_attr("upper"))
71 changes: 71 additions & 0 deletions tensorflow_addons/activations/softshrink_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import softshrink
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
with self.assertRaisesOpError(
"lower must be less than or equal to upper."): # pylint: disable=bad-continuation
y = softshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
self.evaluate(y)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_softshrink(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype)
self.assertAllCloseAccordingToType(softshrink(x), expected_result)

expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype)
self.assertAllCloseAccordingToType(
softshrink(x, lower=-0.5, upper=0.5), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian

# Softshrink is not continuous at `lower` and `upper`.
# Avoid these two points to make gradients smooth.
x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(softshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = softshrink.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), softshrink(x))


if __name__ == "__main__":
tf.test.main()
26 changes: 26 additions & 0 deletions tensorflow_addons/custom_ops/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,28 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "softshrink_op_gpu",
srcs = [
"cc/kernels/softshrink_op.h",
"cc/kernels/softshrink_op_gpu.cu.cc",
],
copts = if_cuda_is_configured([
"-DGOOGLE_CUDA=1",
"-x cuda",
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]),
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart_static",
]),
alwayslink = 1,
)

cc_library(
name = "tanhshrink_op_gpu",
srcs = [
Expand Down Expand Up @@ -102,11 +124,14 @@ cc_binary(
"cc/kernels/hardshrink_op.h",
"cc/kernels/lisht_op.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/softshrink_op.cc",
"cc/kernels/softshrink_op.h",
"cc/kernels/tanhshrink_op.cc",
"cc/kernels/tanhshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/lisht_op.cc",
"cc/ops/softshrink_op.cc",
"cc/ops/tanhshrink_op.cc",
],
copts = [
Expand All @@ -122,6 +147,7 @@ cc_binary(
":gelu_op_gpu",
":hardshrink_op_gpu",
":lisht_op_gpu",
":softshrink_op_gpu",
":tanhshrink_op_gpu",
]),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_SOFTSHRINK_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Softshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftshrinkOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("Addons>SoftshrinkGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
SoftshrinkGradOp<CPUDevice, type>);

// Softshrink only makes sense with floating points.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SOFTSHRINK_KERNELS);
#undef REGISTER_SOFTSHRINK_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Softshrink<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower, \
T upper, typename TTypes<T>::Tensor activations); \
extern template struct Softshrink<GPUDevice, T>; \
\
template <> \
void SoftshrinkGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, T lower, T upper, \
typename TTypes<T>::Tensor backprops); \
extern template struct SoftshrinkGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

// Registration of the GPU implementations.
#define REGISTER_SOFTSHRINK_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Softshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SoftshrinkOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("Addons>SoftshrinkGrad") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T"), \
SoftshrinkGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SOFTSHRINK_GPU_KERNELS);
#undef REGISTER_SOFTSHRINK_GPU_KERNELS

#endif // GOOGLE_CUDA

} // namespace addons
} // namespace tensorflow
Loading