From 1af8077f9860d8bb45bc56e6c43a39930a2565fb Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Tue, 22 Feb 2022 23:10:04 +0800 Subject: [PATCH 01/10] support leaky_relu half kernel --- oneflow/user/kernels/leaky_relu_kernel.cu | 15 +++++++++++++++ python/oneflow/test/modules/test_activation.py | 11 +++++++++++ 2 files changed, 26 insertions(+) diff --git a/oneflow/user/kernels/leaky_relu_kernel.cu b/oneflow/user/kernels/leaky_relu_kernel.cu index 41b5aebc844..0ac8a7e61ca 100644 --- a/oneflow/user/kernels/leaky_relu_kernel.cu +++ b/oneflow/user/kernels/leaky_relu_kernel.cu @@ -31,8 +31,21 @@ __global__ void LeakyReluBackwardGpu(const int n, const float alpha, const T* x, CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = x[i] > 0 ? dy[i] : dy[i] * alpha; } } +template<> +__global__ void LeakyReluForwardGpu(const int n, const float alpha, const half* x, half* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] > static_cast(0) ? x[i] : x[i] * static_cast(alpha); } +} + +template<> +__global__ void LeakyReluBackwardGpu(const int n, const float alpha, const half* x, const half* dy, + half* dx) { + CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = x[i] > static_cast(0) ? dy[i] : dy[i] * static_cast(alpha); } +} + } // namespace + + template class GpuLeakyReluKernel final : public user_op::OpKernel { public: @@ -58,6 +71,7 @@ class GpuLeakyReluKernel final : public user_op::OpKernel { .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("y", 0) == GetDataType::value)); +REGISTER_CUDA_LEAKY_RELU_KERNEL(half) REGISTER_CUDA_LEAKY_RELU_KERNEL(float) REGISTER_CUDA_LEAKY_RELU_KERNEL(double) @@ -87,6 +101,7 @@ class GpuLeakyReluGradKernel final : public user_op::OpKernel { .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value)); +REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(half) REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(float) REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(double) diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index 1657eeb339d..8b8f31b83b4 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -587,6 +587,17 @@ def test_leakyrelu_module_with_random_data(test_case): y = m(x) return y + @autotest() + def test_leakyrelu_module_with_half_random_data(test_case): + m = torch.nn.LeakyReLU(negative_slope=random() | nothing()) + m.train(random()) + device = random_device() + m.to(device) + x = random_tensor().to(device) + x = x.to(torch.float16) + y = m(x) + return y + @autotest() def test_leakyrelu_module_with_0dim_data(test_case): m = torch.nn.LeakyReLU(negative_slope=random() | nothing()) From ae5acc88bc675714bcdcf7845c939fdabdff4b23 Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 09:33:39 +0800 Subject: [PATCH 02/10] add half leaky relu --- oneflow/user/kernels/activation_kernels.cpp | 1 + oneflow/user/kernels/activation_kernels.cu | 23 +++++ oneflow/user/kernels/activation_kernels.h | 32 ++++++ oneflow/user/kernels/leaky_relu_kernel.cpp | 78 -------------- oneflow/user/kernels/leaky_relu_kernel.cu | 108 -------------------- 5 files changed, 56 insertions(+), 186 deletions(-) delete mode 100644 oneflow/user/kernels/leaky_relu_kernel.cpp delete mode 100644 oneflow/user/kernels/leaky_relu_kernel.cu diff --git a/oneflow/user/kernels/activation_kernels.cpp b/oneflow/user/kernels/activation_kernels.cpp index 5a33fb07980..b53006d9fd9 100644 --- a/oneflow/user/kernels/activation_kernels.cpp +++ b/oneflow/user/kernels/activation_kernels.cpp @@ -27,6 +27,7 @@ namespace oneflow { REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype); \ REGISTER_SELU_KERNEL(DeviceType::kCPU, dtype); \ REGISTER_SOFTSIGN_KERNEL(DeviceType::kCPU, dtype); \ + REGISTER_LEAKYRELU_KERNEL(DeviceType::kCPU, dtype); \ REGISTER_RELU_BACKWARD_KERNEL(DeviceType::kCPU, dtype); REGISTER_ACTIVATION_CPU_KERNEL(float); diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index a237663f621..c8fcf0dafcf 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -40,6 +40,28 @@ struct EluGradFunctor { EluGradFunctor float_functor; }; +template<> +struct LeakyReluFunctor { + OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) + : alpha(alpha), float_functor(LeakyReluFunctor(alpha)) {} + OF_DEVICE_FUNC half operator()(half x) const { + return __float2half(float_functor(__half2float(x))); + } + const float alpha; + LeakyReluFunctor float_functor; +}; + +template<> +struct LeakyReluGradFunctor { + OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) + : alpha(alpha), float_functor(LeakyReluGradFunctor(alpha)) {} + OF_DEVICE_FUNC half operator()(half x, half dy) const { + return __float2half(float_functor(__half2float(x), __half2float(dy))); + } + const float alpha; + LeakyReluGradFunctor float_functor; +}; + template<> struct CeluFunctor { OF_DEVICE_FUNC explicit CeluFunctor(float alpha) @@ -173,6 +195,7 @@ struct ReluGradFunctor { REGISTER_SILU_KERNEL(DeviceType::kCUDA, dtype); \ REGISTER_SELU_KERNEL(DeviceType::kCUDA, dtype); \ REGISTER_SOFTSIGN_KERNEL(DeviceType::kCUDA, dtype); \ + REGISTER_LEAKYRELU_KERNEL(DeviceType::kCUDA, dtype); \ REGISTER_RELU_BACKWARD_KERNEL(DeviceType::kCUDA, dtype); namespace { diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h index 578fb1da176..0b43a7712ee 100644 --- a/oneflow/user/kernels/activation_kernels.h +++ b/oneflow/user/kernels/activation_kernels.h @@ -19,6 +19,24 @@ limitations under the License. namespace oneflow { +template +struct LeakyReluFunctor { + OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} + OF_DEVICE_FUNC T operator()(T x) const { + return (x > static_cast(0)) ? x : static_cast(alpha * x); + } + const T alpha; +}; + +template +struct LeakyReluGradFunctor { + OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} + OF_DEVICE_FUNC T operator()(T x, T dy) const { + return (x > static_cast(0)) ? dy : static_cast(dy * alpha); + } + const T alpha; +}; + template struct EluFunctor { OF_DEVICE_FUNC explicit EluFunctor(float alpha) : alpha(alpha) {} @@ -232,6 +250,20 @@ struct ReluGradFunctor { }, \ "dx", "x", "dy"); +#define REGISTER_LEAKYRELU_KERNEL(device, dtype) \ + REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ + device, "leaky_relu", LeakyReluFunctor, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { \ + return LeakyReluFunctor(ctx->Attr("alpha")); \ + }, \ + "y", "x"); \ + REGISTER_BINARY_ELEMWISE_USER_KERNEL( \ + device, "leaky_relu_grad", LeakyReluGradFunctor, dtype, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { \ + return LeakyReluGradFunctor(ctx->Attr("alpha")); \ + }, \ + "dx", "x", "dy"); + #define REGISTER_CELU_KERNEL(device, dtype) \ REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ device, "celu", CeluFunctor, dtype, dtype, \ diff --git a/oneflow/user/kernels/leaky_relu_kernel.cpp b/oneflow/user/kernels/leaky_relu_kernel.cpp deleted file mode 100644 index 3a562b431e3..00000000000 --- a/oneflow/user/kernels/leaky_relu_kernel.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/framework/framework.h" - -namespace oneflow { - -template -class CpuLeakyReluKernel final : public user_op::OpKernel { - public: - CpuLeakyReluKernel() = default; - ~CpuLeakyReluKernel() = default; - - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); - user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const int32_t elem_cnt = x->shape().elem_cnt(); - const float alpha = ctx->Attr("alpha"); - const T* x_ptr = x->dptr(); - T* y_ptr = y->mut_dptr(); - FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : x_ptr[i] * alpha; } - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_CPU_LEAKY_RELU_KERNEL(dtype) \ - REGISTER_USER_KERNEL("leaky_relu") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ - && (user_op::HobDataType("y", 0) == GetDataType::value)); - -REGISTER_CPU_LEAKY_RELU_KERNEL(float) -REGISTER_CPU_LEAKY_RELU_KERNEL(double) - -template -class CpuLeakyReluGradKernel final : public user_op::OpKernel { - public: - CpuLeakyReluGradKernel() = default; - ~CpuLeakyReluGradKernel() = default; - - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); - const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); - user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const int32_t elem_cnt = x->shape().elem_cnt(); - const float alpha = ctx->Attr("alpha"); - const T* x_ptr = x->dptr(); - const T* dy_ptr = dy->dptr(); - T* dx_ptr = dx->mut_dptr(); - FOR_RANGE(int32_t, i, 0, elem_cnt) { dx_ptr[i] = x_ptr[i] > 0 ? dy_ptr[i] : dy_ptr[i] * alpha; } - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_CPU_LEAKY_RELU_GRAD_KERNEL(dtype) \ - REGISTER_USER_KERNEL("leaky_relu_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ - && (user_op::HobDataType("dx", 0) == GetDataType::value)); - -REGISTER_CPU_LEAKY_RELU_GRAD_KERNEL(float) -REGISTER_CPU_LEAKY_RELU_GRAD_KERNEL(double) - -} // namespace oneflow diff --git a/oneflow/user/kernels/leaky_relu_kernel.cu b/oneflow/user/kernels/leaky_relu_kernel.cu deleted file mode 100644 index 0ac8a7e61ca..00000000000 --- a/oneflow/user/kernels/leaky_relu_kernel.cu +++ /dev/null @@ -1,108 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/device/cuda_util.h" - -namespace oneflow { - -namespace { - -template -__global__ void LeakyReluForwardGpu(const int n, const float alpha, const T* x, T* y) { - CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] > 0 ? x[i] : x[i] * alpha; } -} - -template -__global__ void LeakyReluBackwardGpu(const int n, const float alpha, const T* x, const T* dy, - T* dx) { - CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = x[i] > 0 ? dy[i] : dy[i] * alpha; } -} - -template<> -__global__ void LeakyReluForwardGpu(const int n, const float alpha, const half* x, half* y) { - CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] > static_cast(0) ? x[i] : x[i] * static_cast(alpha); } -} - -template<> -__global__ void LeakyReluBackwardGpu(const int n, const float alpha, const half* x, const half* dy, - half* dx) { - CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = x[i] > static_cast(0) ? dy[i] : dy[i] * static_cast(alpha); } -} - -} // namespace - - - -template -class GpuLeakyReluKernel final : public user_op::OpKernel { - public: - GpuLeakyReluKernel() = default; - ~GpuLeakyReluKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); - user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const int32_t elem_cnt = x->shape().elem_cnt(); - const float alpha = ctx->Attr("alpha"); - RUN_CUDA_KERNEL((LeakyReluForwardGpu), ctx->stream(), elem_cnt, elem_cnt, alpha, - x->dptr(), y->mut_dptr()); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_CUDA_LEAKY_RELU_KERNEL(dtype) \ - REGISTER_USER_KERNEL("leaky_relu") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("y", 0) == GetDataType::value)); - -REGISTER_CUDA_LEAKY_RELU_KERNEL(half) -REGISTER_CUDA_LEAKY_RELU_KERNEL(float) -REGISTER_CUDA_LEAKY_RELU_KERNEL(double) - -template -class GpuLeakyReluGradKernel final : public user_op::OpKernel { - public: - GpuLeakyReluGradKernel() = default; - ~GpuLeakyReluGradKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); - const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); - user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const int32_t elem_cnt = x->shape().elem_cnt(); - const float alpha = ctx->Attr("alpha"); - RUN_CUDA_KERNEL((LeakyReluBackwardGpu), ctx->stream(), elem_cnt, elem_cnt, alpha, - x->dptr(), dy->dptr(), dx->mut_dptr()); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(dtype) \ - REGISTER_USER_KERNEL("leaky_relu_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("dx", 0) == GetDataType::value)); - -REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(half) -REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(float) -REGISTER_CUDA_LEAKY_RELU_GRAD_KERNEL(double) - -} // namespace oneflow From fa420ad2e2b1b89eb1b3ca7ba5b2a5acfe68ac63 Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 09:50:33 +0800 Subject: [PATCH 03/10] fix comment --- oneflow/user/kernels/activation_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h index 0b43a7712ee..942cbfba6e0 100644 --- a/oneflow/user/kernels/activation_kernels.h +++ b/oneflow/user/kernels/activation_kernels.h @@ -23,7 +23,7 @@ template struct LeakyReluFunctor { OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} OF_DEVICE_FUNC T operator()(T x) const { - return (x > static_cast(0)) ? x : static_cast(alpha * x); + return (x > 0) ? x : alpha * x; } const T alpha; }; @@ -32,7 +32,7 @@ template struct LeakyReluGradFunctor { OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} OF_DEVICE_FUNC T operator()(T x, T dy) const { - return (x > static_cast(0)) ? dy : static_cast(dy * alpha); + return (x > 0) ? dy : dy * alpha; } const T alpha; }; From cbb403e9536bfabf6af6a543ce6b96da521afa29 Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 10:06:29 +0800 Subject: [PATCH 04/10] add half leaky relu impl --- oneflow/user/kernels/activation_kernels.cu | 10 ++++++---- oneflow/user/kernels/activation_kernels.h | 8 ++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index c8fcf0dafcf..745795fff2f 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -44,8 +44,9 @@ template<> struct LeakyReluFunctor { OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha), float_functor(LeakyReluFunctor(alpha)) {} - OF_DEVICE_FUNC half operator()(half x) const { - return __float2half(float_functor(__half2float(x))); + __device__ half operator()(half x) const { + half zero = __float2half(0); + return (x > zero) ? x : __float2half(alpha) * x; } const float alpha; LeakyReluFunctor float_functor; @@ -55,8 +56,9 @@ template<> struct LeakyReluGradFunctor { OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha), float_functor(LeakyReluGradFunctor(alpha)) {} - OF_DEVICE_FUNC half operator()(half x, half dy) const { - return __float2half(float_functor(__half2float(x), __half2float(dy))); + __device__ half operator()(half x, half dy) const { + half zero = __float2half(0); + return (x > zero) ? dy : __float2half(alpha) * dy; } const float alpha; LeakyReluGradFunctor float_functor; diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h index 942cbfba6e0..e81553e9004 100644 --- a/oneflow/user/kernels/activation_kernels.h +++ b/oneflow/user/kernels/activation_kernels.h @@ -22,18 +22,14 @@ namespace oneflow { template struct LeakyReluFunctor { OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} - OF_DEVICE_FUNC T operator()(T x) const { - return (x > 0) ? x : alpha * x; - } + OF_DEVICE_FUNC T operator()(T x) const { return (x > 0) ? x : alpha * x; } const T alpha; }; template struct LeakyReluGradFunctor { OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} - OF_DEVICE_FUNC T operator()(T x, T dy) const { - return (x > 0) ? dy : dy * alpha; - } + OF_DEVICE_FUNC T operator()(T x, T dy) const { return (x > 0) ? dy : dy * alpha; } const T alpha; }; From d32a237101a9b0c6d5238c60e439bc47919c4660 Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 10:12:09 +0800 Subject: [PATCH 05/10] fix comment --- oneflow/user/kernels/activation_kernels.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index 745795fff2f..117aab9cfc9 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -43,25 +43,23 @@ struct EluGradFunctor { template<> struct LeakyReluFunctor { OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) - : alpha(alpha), float_functor(LeakyReluFunctor(alpha)) {} + : alpha(alpha) {} __device__ half operator()(half x) const { half zero = __float2half(0); return (x > zero) ? x : __float2half(alpha) * x; } const float alpha; - LeakyReluFunctor float_functor; }; template<> struct LeakyReluGradFunctor { OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) - : alpha(alpha), float_functor(LeakyReluGradFunctor(alpha)) {} + : alpha(alpha) {} __device__ half operator()(half x, half dy) const { half zero = __float2half(0); return (x > zero) ? dy : __float2half(alpha) * dy; } const float alpha; - LeakyReluGradFunctor float_functor; }; template<> From a36481569d6864916348b37e988792a85998472e Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 10:28:46 +0800 Subject: [PATCH 06/10] fix comment --- oneflow/user/kernels/activation_kernels.cu | 12 +++++++----- oneflow/user/kernels/activation_kernels.h | 8 +++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index 117aab9cfc9..b8cf2238e47 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -42,8 +42,7 @@ struct EluGradFunctor { template<> struct LeakyReluFunctor { - OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) - : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} __device__ half operator()(half x) const { half zero = __float2half(0); return (x > zero) ? x : __float2half(alpha) * x; @@ -53,11 +52,14 @@ struct LeakyReluFunctor { template<> struct LeakyReluGradFunctor { - OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) - : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} __device__ half operator()(half x, half dy) const { half zero = __float2half(0); - return (x > zero) ? dy : __float2half(alpha) * dy; + if (alpha > 0) { + return (dy > zero) ? dy : __float2half(alpha) * dy; + } else { + return (x > zero) ? dy : __float2half(alpha) * dy; + } } const float alpha; }; diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h index e81553e9004..b8a0e907438 100644 --- a/oneflow/user/kernels/activation_kernels.h +++ b/oneflow/user/kernels/activation_kernels.h @@ -29,7 +29,13 @@ struct LeakyReluFunctor { template struct LeakyReluGradFunctor { OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} - OF_DEVICE_FUNC T operator()(T x, T dy) const { return (x > 0) ? dy : dy * alpha; } + OF_DEVICE_FUNC T operator()(T x, T dy) const { + if (alpha > 0) { + return dy > 0 ? dy : dy * alpha; + } else { + return (x > 0) ? dy : dy * alpha; + } + } const T alpha; }; From 5bc4546ae467011db1c3f006eb3e77ea7605f034 Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 10:34:59 +0800 Subject: [PATCH 07/10] revert --- oneflow/user/kernels/activation_kernels.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index b8cf2238e47..117aab9cfc9 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -42,7 +42,8 @@ struct EluGradFunctor { template<> struct LeakyReluFunctor { - OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) + : alpha(alpha) {} __device__ half operator()(half x) const { half zero = __float2half(0); return (x > zero) ? x : __float2half(alpha) * x; @@ -52,14 +53,11 @@ struct LeakyReluFunctor { template<> struct LeakyReluGradFunctor { - OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) + : alpha(alpha) {} __device__ half operator()(half x, half dy) const { half zero = __float2half(0); - if (alpha > 0) { - return (dy > zero) ? dy : __float2half(alpha) * dy; - } else { - return (x > zero) ? dy : __float2half(alpha) * dy; - } + return (x > zero) ? dy : __float2half(alpha) * dy; } const float alpha; }; From 29c253daf4e9c69b48cea09e35e9685f7a2c9ec1 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 23 Feb 2022 03:04:07 +0000 Subject: [PATCH 08/10] auto format by CI --- oneflow/user/kernels/activation_kernels.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index 117aab9cfc9..b65999cf687 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -42,8 +42,7 @@ struct EluGradFunctor { template<> struct LeakyReluFunctor { - OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) - : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {} __device__ half operator()(half x) const { half zero = __float2half(0); return (x > zero) ? x : __float2half(alpha) * x; @@ -53,8 +52,7 @@ struct LeakyReluFunctor { template<> struct LeakyReluGradFunctor { - OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) - : alpha(alpha) {} + OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {} __device__ half operator()(half x, half dy) const { half zero = __float2half(0); return (x > zero) ? dy : __float2half(alpha) * dy; From 0d3fe7a17d6c15de89f062a836a0fe88afcf267a Mon Sep 17 00:00:00 2001 From: bbuf <1182563586@qq.com> Date: Sat, 26 Feb 2022 10:32:38 +0800 Subject: [PATCH 09/10] fix error --- python/oneflow/test/modules/test_activation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index 8b8f31b83b4..82acec5cfcc 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -587,7 +587,9 @@ def test_leakyrelu_module_with_random_data(test_case): y = m(x) return y + @autotest() + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_leakyrelu_module_with_half_random_data(test_case): m = torch.nn.LeakyReLU(negative_slope=random() | nothing()) m.train(random()) From 3a4cf474f79bd40560b6af3ad13fb4226639ba73 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Sat, 26 Feb 2022 02:37:33 +0000 Subject: [PATCH 10/10] auto format by CI --- python/oneflow/test/modules/test_activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index 82acec5cfcc..1e8aaf14300 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -587,7 +587,6 @@ def test_leakyrelu_module_with_random_data(test_case): y = m(x) return y - @autotest() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_leakyrelu_module_with_half_random_data(test_case):