Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support leaky_relu half kernel #7569

Merged
merged 25 commits into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1af8077
support leaky_relu half kernel
BBuf Feb 22, 2022
ae5acc8
add half leaky relu
BBuf Feb 23, 2022
fa420ad
fix comment
BBuf Feb 23, 2022
cbb403e
add half leaky relu impl
BBuf Feb 23, 2022
d32a237
fix comment
BBuf Feb 23, 2022
a364815
fix comment
BBuf Feb 23, 2022
5bc4546
revert
BBuf Feb 23, 2022
8f942c9
Merge branch 'master' into add_half_leaky_relu
BBuf Feb 23, 2022
29c253d
auto format by CI
oneflow-ci-bot Feb 23, 2022
3367e86
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
8d9a4a9
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
3a4dfec
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
d991e19
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
ecbb38c
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
1c811be
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 23, 2022
2ad4f34
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
c5a856f
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
2b87e87
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
d77be21
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
c6bb208
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
1e6d489
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 24, 2022
0d3fe7a
fix error
BBuf Feb 26, 2022
68db300
Merge branch 'master' into add_half_leaky_relu
BBuf Feb 26, 2022
3a4cf47
auto format by CI
oneflow-ci-bot Feb 26, 2022
e854237
Merge branch 'master' into add_half_leaky_relu
oneflow-ci-bot Feb 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions oneflow/user/kernels/activation_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace oneflow {
REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_SELU_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_SOFTSIGN_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_LEAKYRELU_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_RELU_BACKWARD_KERNEL(DeviceType::kCPU, dtype);

REGISTER_ACTIVATION_CPU_KERNEL(float);
Expand Down
21 changes: 21 additions & 0 deletions oneflow/user/kernels/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ struct EluGradFunctor<half> {
EluGradFunctor<float> float_functor;
};

template<>
struct LeakyReluFunctor<half> {
  OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {}
  // y = x for x > 0, alpha * x otherwise.
  // The negative branch is computed in float so that alpha keeps full
  // precision (rounding alpha to half first, as in `__float2half(alpha) * x`,
  // introduces an extra rounding error); only the final result is rounded
  // back to half. Mirrors the wrap-in-float pattern used by
  // EluGradFunctor<half> in this file.
  __device__ half operator()(half x) const {
    half zero = __float2half(0.f);
    return (x > zero) ? x : __float2half(alpha * __half2float(x));
  }
  const float alpha;
};

template<>
struct LeakyReluGradFunctor<half> {
  OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {}
  // dx = dy where the forward input x was positive, alpha * dy otherwise.
  // Branches on x (the forward input), never on dy. The scaled branch is
  // computed in float (alpha at full precision, single rounding at the end)
  // instead of multiplying two rounded halves.
  __device__ half operator()(half x, half dy) const {
    half zero = __float2half(0.f);
    return (x > zero) ? dy : __float2half(alpha * __half2float(dy));
  }
  const float alpha;
};

template<>
struct CeluFunctor<half> {
OF_DEVICE_FUNC explicit CeluFunctor(float alpha)
Expand Down Expand Up @@ -173,6 +193,7 @@ struct ReluGradFunctor<half> {
REGISTER_SILU_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_SELU_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_SOFTSIGN_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_LEAKYRELU_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_RELU_BACKWARD_KERNEL(DeviceType::kCUDA, dtype);

namespace {
Expand Down
34 changes: 34 additions & 0 deletions oneflow/user/kernels/activation_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@ limitations under the License.

namespace oneflow {

template<typename T>
struct LeakyReluFunctor {
  OF_DEVICE_FUNC explicit LeakyReluFunctor(float alpha) : alpha(alpha) {}
  // LeakyReLU forward: the identity on positive inputs, a slope of alpha
  // elsewhere.
  OF_DEVICE_FUNC T operator()(T x) const {
    if (x > 0) { return x; }
    return alpha * x;
  }
  const T alpha;
};

template<typename T>
struct LeakyReluGradFunctor {
  OF_DEVICE_FUNC explicit LeakyReluGradFunctor(float alpha) : alpha(alpha) {}
  // LeakyReLU backward: the gradient passes through unchanged where the
  // forward input x was positive, and is scaled by alpha otherwise.
  //
  // Fix: the previous version special-cased `alpha > 0` and in that branch
  // tested `dy > 0` instead of `x > 0`, yielding a wrong gradient whenever
  // sign(dy) != sign(x). The branch must always be on x, and the alpha
  // special case is unnecessary.
  OF_DEVICE_FUNC T operator()(T x, T dy) const { return (x > 0) ? dy : dy * alpha; }
  const T alpha;
};

template<typename T>
struct EluFunctor {
OF_DEVICE_FUNC explicit EluFunctor(float alpha) : alpha(alpha) {}
Expand Down Expand Up @@ -232,6 +252,20 @@ struct ReluGradFunctor {
}, \
"dx", "x", "dy");

// Registers both LeakyReLU user-op kernels for a (device, dtype) pair:
//   - "leaky_relu":      unary elementwise, output "y" from input "x"
//   - "leaky_relu_grad": binary elementwise, output "dx" from inputs "x", "dy"
// Both functors read the float op attribute "alpha" from the compute context.
#define REGISTER_LEAKYRELU_KERNEL(device, dtype) \
REGISTER_UNARY_ELEMWISE_USER_KERNEL( \
device, "leaky_relu", LeakyReluFunctor, dtype, dtype, \
[](user_op::KernelComputeContext* ctx) { \
return LeakyReluFunctor<dtype>(ctx->Attr<float>("alpha")); \
}, \
"y", "x"); \
REGISTER_BINARY_ELEMWISE_USER_KERNEL( \
device, "leaky_relu_grad", LeakyReluGradFunctor, dtype, dtype, dtype, \
[](user_op::KernelComputeContext* ctx) { \
return LeakyReluGradFunctor<dtype>(ctx->Attr<float>("alpha")); \
}, \
"dx", "x", "dy");

#define REGISTER_CELU_KERNEL(device, dtype) \
REGISTER_UNARY_ELEMWISE_USER_KERNEL( \
device, "celu", CeluFunctor, dtype, dtype, \
Expand Down
78 changes: 0 additions & 78 deletions oneflow/user/kernels/leaky_relu_kernel.cpp

This file was deleted.

93 changes: 0 additions & 93 deletions oneflow/user/kernels/leaky_relu_kernel.cu

This file was deleted.

12 changes: 12 additions & 0 deletions python/oneflow/test/modules/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,18 @@ def test_leakyrelu_module_with_random_data(test_case):
y = m(x)
return y

@autotest()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
def test_leakyrelu_module_with_half_random_data(test_case):
    # Exercises the float16 LeakyReLU kernel path by casting the input to
    # torch.float16 before the forward pass. Skipped when the CI run is
    # restricted to CPU via ONEFLOW_TEST_CPU_ONLY (presumably because the
    # half kernel targets CUDA — TODO confirm against the kernel registration).
    m = torch.nn.LeakyReLU(negative_slope=random() | nothing())
    m.train(random())
    device = random_device()
    m.to(device)
    x = random_tensor().to(device)
    x = x.to(torch.float16)
    y = m(x)
    return y

@autotest()
def test_leakyrelu_module_with_0dim_data(test_case):
m = torch.nn.LeakyReLU(negative_slope=random() | nothing())
Expand Down