diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index 5727213ec3..af015cef5f 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -9,6 +9,7 @@ py_library(
         "gelu.py",
         "hardshrink.py",
         "lisht.py",
+        "softshrink.py",
         "sparsemax.py",
         "tanhshrink.py",
     ],
@@ -84,6 +85,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "softshrink_test",
+    size = "small",
+    srcs = [
+        "softshrink_test.py",
+    ],
+    main = "softshrink_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "tanhshrink_test",
     size = "small",
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index c548a32923..14522f3bc6 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -6,6 +6,7 @@
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com |
 | lisht | @WindQAQ | windqaq@gmail.com |
+| softshrink| @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink| @fsx950223 | fsx950223@gmail.com |
 
@@ -15,6 +16,7 @@
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
 | lisht | lisht | https://arxiv.org/abs/1901.05894 |
+| softshrink| softshrink | |
 | sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink| tanhshrink | |
 
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index 313a78a1e3..ba9d6a3738 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -21,5 +21,6 @@
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
 from tensorflow_addons.activations.lisht import lisht
+from tensorflow_addons.activations.softshrink import softshrink
 from tensorflow_addons.activations.sparsemax import sparsemax
 from tensorflow_addons.activations.tanhshrink import tanhshrink
diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index 31a4b82196..d685e7d2ca 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,7 @@ class ActivationsTest(tf.test.TestCase):
 
     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
     ]
 
     def test_serialization(self):
diff --git a/tensorflow_addons/activations/softshrink.py b/tensorflow_addons/activations/softshrink.py
new file mode 100644
index 0000000000..11cfe51467
--- /dev/null
+++ b/tensorflow_addons/activations/softshrink.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def softshrink(x, lower=-1.0, upper=1.0):
+    """Soft shrink function.
+
+    Computes soft shrink function:
+    `x - lower if x < lower, x - upper if x > upper else 0`.
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+        lower: `float`, lower bound for setting values to zeros.
+        upper: `float`, upper bound for setting values to zeros.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_softshrink(x, lower, upper)
+
+
+@tf.RegisterGradient("Addons>Softshrink")
+def _softshrink_grad(op, grad):
+    return _activation_ops_so.addons_softshrink_grad(grad, op.inputs[0],
+                                                     op.get_attr("lower"),
+                                                     op.get_attr("upper"))
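For context, a minimal usage sketch of the `softshrink` wrapper defined above. It assumes a build in which the compiled `_activation_ops.so` is importable through `tensorflow_addons`; the model and input values are illustrative only and not part of the change:

```python
import tensorflow as tf
from tensorflow_addons.activations import softshrink

# Element-wise soft shrink with the default thresholds lower=-1.0, upper=1.0.
x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
print(softshrink(x))                         # expected: [-1.0, 0.0, 0.0, 0.0, 1.0]
print(softshrink(x, lower=-0.5, upper=0.5))  # expected: [-1.5, 0.0, 0.0, 0.0, 1.5]

# The function can also be passed directly as a Keras activation.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation=softshrink, input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
```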
diff --git a/tensorflow_addons/activations/softshrink_test.py b/tensorflow_addons/activations/softshrink_test.py
new file mode 100644
index 0000000000..9fadd12428
--- /dev/null
+++ b/tensorflow_addons/activations/softshrink_test.py
@@ -0,0 +1,71 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import softshrink
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase):
+    def test_invalid(self):
+        with self.assertRaisesOpError(
+                "lower must be less than or equal to upper."):  # pylint: disable=bad-continuation
+            y = softshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
+            self.evaluate(y)
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_softshrink(self, dtype):
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(softshrink(x), expected_result)
+
+        expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype)
+        self.assertAllCloseAccordingToType(
+            softshrink(x, lower=-0.5, upper=0.5), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing the Jacobian.
+
+        # Softshrink is not differentiable at `lower` and `upper`.
+        # Avoid these two points to keep the gradients well behaved.
+        x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(softshrink, [x])
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+
+    def test_unknown_shape(self):
+        fn = softshrink.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), softshrink(x))
+
+
+if __name__ == "__main__":
+    tf.test.main()
diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD
index e86802bdde..56b0f00d3b 100644
--- a/tensorflow_addons/custom_ops/activations/BUILD
+++ b/tensorflow_addons/custom_ops/activations/BUILD
@@ -71,6 +71,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "softshrink_op_gpu",
+    srcs = [
+        "cc/kernels/softshrink_op.h",
+        "cc/kernels/softshrink_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudart_static",
+    ]),
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tanhshrink_op_gpu",
     srcs = [
@@ -102,11 +124,14 @@ cc_binary(
         "cc/kernels/hardshrink_op.h",
         "cc/kernels/lisht_op.cc",
         "cc/kernels/lisht_op.h",
+        "cc/kernels/softshrink_op.cc",
+        "cc/kernels/softshrink_op.h",
         "cc/kernels/tanhshrink_op.cc",
         "cc/kernels/tanhshrink_op.h",
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
         "cc/ops/lisht_op.cc",
+        "cc/ops/softshrink_op.cc",
         "cc/ops/tanhshrink_op.cc",
     ],
     copts = [
@@ -122,6 +147,7 @@ cc_binary(
         ":gelu_op_gpu",
        ":hardshrink_op_gpu",
         ":lisht_op_gpu",
+        ":softshrink_op_gpu",
         ":tanhshrink_op_gpu",
     ]),
 )
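Before the C++ kernels below, a plain-NumPy sketch of the element-wise math the `Softshrink` and `SoftshrinkGrad` functors are expected to implement. The reference functions are illustrative only and not part of the change:

```python
import numpy as np

def softshrink_reference(x, lower=-1.0, upper=1.0):
    # x - lower where x < lower, x - upper where x > upper, 0 otherwise.
    return np.where(x < lower, x - lower, np.where(x > upper, x - upper, 0.0))

def softshrink_grad_reference(grad, x, lower=-1.0, upper=1.0):
    # The incoming gradient passes through outside [lower, upper] and is zeroed inside.
    return np.where((x < lower) | (x > upper), grad, 0.0)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(softshrink_reference(x))                        # [-1.  0.  0.  0.  1.]
print(softshrink_grad_reference(np.ones_like(x), x))  # [1. 0. 0. 0. 1.]
```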
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.cc
new file mode 100644
index 0000000000..4c2fe353a8
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.cc
@@ -0,0 +1,81 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_SOFTSHRINK_KERNELS(type)                                     \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Softshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      SoftshrinkOp<CPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>SoftshrinkGrad")                       \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          SoftshrinkGradOp<CPUDevice, type>);
+
+// Softshrink only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SOFTSHRINK_KERNELS);
+#undef REGISTER_SOFTSHRINK_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                   \
+  template <>                                                                 \
+  void Softshrink<GPUDevice, T>::operator()(                                  \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower,  \
+      T upper, typename TTypes<T>::Tensor activations);                       \
+  extern template struct Softshrink<GPUDevice, T>;                            \
+                                                                              \
+  template <>                                                                 \
+  void SoftshrinkGrad<GPUDevice, T>::operator()(                              \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,          \
+      typename TTypes<T>::ConstTensor features, T lower, T upper,             \
+      typename TTypes<T>::Tensor backprops);                                  \
+  extern template struct SoftshrinkGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_SOFTSHRINK_GPU_KERNELS(type)                                 \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Softshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      SoftshrinkOp<GPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>SoftshrinkGrad")                       \
+                              .Device(DEVICE_GPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          SoftshrinkGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SOFTSHRINK_GPU_KERNELS);
+#undef REGISTER_SOFTSHRINK_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h
new file mode 100644
index 0000000000..ae09d87db0
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h
@@ -0,0 +1,147 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_SOFTSHRINK_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_SOFTSHRINK_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+namespace functor {
+
+// Functor used by SoftshrinkOp to do the computations.
+template <typename Device, typename T>
+struct Softshrink {
+  // Computes Softshrink activation.
+  //
+  // features: any shape.
+  // lower: the lower bound for setting values to zeros.
+  // upper: the upper bound for setting values to zeros.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  T lower, T upper, typename TTypes<T>::Tensor activations) {
+    activations.device(d) =
+        (features < lower)
+            .select(features - features.constant(lower),
+                    (features > upper)
+                        .select(features - features.constant(upper),
+                                features.constant(static_cast<T>(0))));
+  }
+};
+
+// Functor used by SoftshrinkGradOp to do the computations.
+template <typename Device, typename T>
+struct SoftshrinkGrad {
+  // Computes SoftshrinkGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Softshrink op.
+  // features: inputs that were passed to the Softshrink op.
+  // lower: the lower bound for setting values to zeros.
+  // upper: the upper bound for setting values to zeros.
+  // backprops: gradients to backpropagate to the Softshrink inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features, T lower, T upper,
+                  typename TTypes<T>::Tensor backprops) {
+    backprops.device(d) =
+        (features < lower || features > upper)
+            .select(gradients, features.constant(static_cast<T>(0)));
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class SoftshrinkOp : public UnaryElementWiseOp<T, SoftshrinkOp<Device, T>> {
+ public:
+  explicit SoftshrinkOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, SoftshrinkOp<Device, T>>::UnaryElementWiseOp(
+            context) {
+    float lower, upper;
+    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower));
+    OP_REQUIRES_OK(context, context->GetAttr("upper", &upper));
+    lower_ = static_cast<T>(lower);
+    upper_ = static_cast<T>(upper);
+
+    OP_REQUIRES(
+        context, lower_ <= upper_,
+        errors::InvalidArgument("lower must be less than or equal to upper."));
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Softshrink<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), lower_, upper_,
+            output->flat<T>());
+  }
+
+ private:
+  T lower_;
+  T upper_;
+};
+
+template <typename Device, typename T>
+class SoftshrinkGradOp
+    : public BinaryElementWiseOp<T, SoftshrinkGradOp<Device, T>> {
+ public:
+  explicit SoftshrinkGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<
+            T, SoftshrinkGradOp<Device, T>>::BinaryElementWiseOp(context) {
+    float lower, upper;
+    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower));
+    OP_REQUIRES_OK(context, context->GetAttr("upper", &upper));
+    lower_ = static_cast<T>(lower);
+    upper_ = static_cast<T>(upper);
+
+    OP_REQUIRES(
+        context, lower_ <= upper_,
+        errors::InvalidArgument("lower must be less than or equal to upper."));
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, T lower, T upper, Tensor* output);
+
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, lower_, upper_, output);
+  }
+
+ private:
+  T lower_;
+  T upper_;
+};
+
+template <typename Device, typename T>
+void SoftshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                                    const Tensor& g,
+                                                    const Tensor& a, T lower,
+                                                    T upper, Tensor* output) {
+  functor::SoftshrinkGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), lower,
+          upper, output->flat<T>());
+}
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_SOFTSHRINK_OP_H_
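A small sketch of how the CPU and GPU registrations above could be cross-checked from Python, assuming TensorFlow Addons was built with CUDA and a GPU is visible; illustrative only and not part of the change:

```python
import tensorflow as tf
from tensorflow_addons.activations import softshrink

x = tf.random.uniform((64,), minval=-3.0, maxval=3.0)

with tf.device("/CPU:0"):
    y_cpu = softshrink(x, lower=-0.5, upper=0.5)

# Only meaningful when a GPU kernel was built and a device is available.
if tf.config.list_physical_devices("GPU"):
    with tf.device("/GPU:0"):
        y_gpu = softshrink(x, lower=-0.5, upper=0.5)
    tf.debugging.assert_near(y_cpu, y_gpu)
```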
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op_gpu.cu.cc
new file mode 100644
index 0000000000..d91c49453b
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/softshrink_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                           \
+  template struct functor::Softshrink<GPUDevice, T>;    \
+  template struct functor::SoftshrinkGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/softshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/softshrink_op.cc
new file mode 100644
index 0000000000..6a568ad565
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/softshrink_op.cc
@@ -0,0 +1,41 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+REGISTER_OP("Addons>Softshrink")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .Attr("lower: float = -1.0")
+    .Attr("upper: float = 1.0")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("Addons>SoftshrinkGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .Attr("lower: float = -1.0")
+    .Attr("upper: float = 1.0")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace addons
+}  // namespace tensorflow
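Finally, a sketch of checking the registered gradient against the closed-form sub-gradient (1 outside `[lower, upper]`, 0 inside), which is what `test_theoretical_gradients` verifies numerically; the sample points avoid the non-differentiable thresholds and are illustrative only:

```python
import tensorflow as tf
from tensorflow_addons.activations import softshrink

x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0])

with tf.GradientTape() as tape:
    tape.watch(x)
    y = softshrink(x)  # default thresholds: lower=-1.0, upper=1.0

grad = tape.gradient(y, x)
expected = tf.constant([1.0, 1.0, 0.0, 1.0, 1.0])  # identity branches have slope 1
tf.debugging.assert_near(grad, expected)
```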