【PaddlePaddle Hackathon 4】No.56 : add fp16 test and bf16 for bernoulli and trunc #51657
The first changed file is the GPU trunc kernel in the phi namespace. It adds per-dtype `device_trunc` overloads (including float16 and bfloat16) and switches `TruncFunctor` over to them:

```diff
@@ -23,14 +23,39 @@ namespace phi {

 using phi::PADDLE_CUDA_NUM_THREADS;

+template <typename T>
+__device__ T device_trunc(T x);
+
+template <>
+__device__ float device_trunc<float>(float x) {
+  return truncf(x);
+}
+
+template <>
+__device__ double device_trunc<double>(double x) {
+  return trunc(x);
+}
+
+template <>
+__device__ phi::dtype::float16 device_trunc<phi::dtype::float16>(
+    phi::dtype::float16 x) {
+  return static_cast<phi::dtype::float16>(truncf(static_cast<float>(x)));
+}
+
+template <>
+__device__ phi::dtype::bfloat16 device_trunc<phi::dtype::bfloat16>(
+    phi::dtype::bfloat16 x) {
+  return static_cast<phi::dtype::bfloat16>(truncf(static_cast<float>(x)));
+}
+
 template <typename T>
 class TruncFunctor {
  public:
-  __device__ TruncFunctor(const T x) : x_(x) {}
-  __device__ T operator()() { return trunc(x_); }
+  __device__ TruncFunctor(T x) : x_(x) {}
+  __device__ T operator()() { return device_trunc(x_); }

  public:
-  const T x_;
+  T x_;
 };

 template <>
```

Review comment on the `TruncFunctor` change: "I think this could also be computed directly with MPType."
Reply: "OK 👌"
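The MPType idea raised above could look roughly like the sketch below. This is only an illustration, not the code in this PR; it assumes `phi::dtype::MPTypeTrait` from phi's amp_type_traits header, which maps float16/bfloat16 to float and leaves float/double unchanged.

```cuda
// Illustrative sketch only (not the merged code): truncate in MPType and
// cast back, instead of specializing device_trunc for every dtype.
// Assumes phi::dtype::MPTypeTrait<T>::Type is float for float16/bfloat16.
#include "paddle/phi/common/amp_type_traits.h"

template <typename T>
class TruncFunctor {
 public:
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  __device__ explicit TruncFunctor(T x) : x_(x) {}

  __device__ T operator()() {
    // Promote to MPType, truncate with the CUDA math overload, cast back.
    return static_cast<T>(trunc(static_cast<MPType>(x_)));
  }

 private:
  T x_;
};
```

The upside of this variant is that the per-dtype `device_trunc` specializations would no longer be needed; the cost is an extra round trip through float for half-precision inputs.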
The second hunk in the same file registers the new dtypes for the GPU trunc kernel:

```diff
@@ -78,5 +103,13 @@ void TruncKernel(const Context& dev_ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    trunc, GPU, ALL_LAYOUT, phi::TruncKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(trunc,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TruncKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
```
The second changed file is the Python unit test for the bernoulli op:

```diff
@@ -15,9 +15,10 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle.fluid import core


 def output_hist(out):

@@ -31,10 +32,15 @@ def output_hist(out):
 class TestBernoulliOp(OpTest):
     def setUp(self):
         self.op_type = "bernoulli"
-        self.inputs = {"X": np.random.uniform(size=(1000, 784))}
+        self.inputs = {
+            "X": np.random.uniform(size=(1000, 784)).astype(self.dtype)
+        }
         self.attrs = {}
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}

+    def init_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output_customized(self.verify_output)
```

Review comment on the `self.outputs` line: "For float16 input, the output shouldn't be float32, should it?"
Reply: "OK 👌"

```diff
@@ -98,5 +104,39 @@ def test_fixed_random_number(self):
         paddle.enable_static()


+class TestBernoulliFP16Op(TestBernoulliOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestBernoulliBF16Op(OpTest):
+    def setUp(self):
+        self.python_api = paddle.bernoulli
+        self.op_type = "bernoulli"
+        self.dtype = np.uint16
+        self.init_test_case()
+
+        self.inputs = {'X': convert_float_to_uint16(self.x)}
+        self.attrs = {}
+        self.outputs = {'Out': convert_float_to_uint16(self.out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place_customized(self.verify_output, place)
+
+    def init_test_case(self):
+        self.x = np.random.uniform(size=(1000, 784)).astype("float32")
+        self.out = np.zeros((1000, 784)).astype("float32")
+
+    def verify_output(self, outs):
+        hist, prob = output_hist(np.array(outs[0]))
+        np.testing.assert_allclose(hist, prob, rtol=0, atol=0.01)
+

 if __name__ == "__main__":
     unittest.main()
```
Review comment on the bernoulli GPU kernel: "I think you can just use MPType here and cast x_data[idx]:"

```cpp
out_data[idx] = static_cast<T>((&rand.x)[j] <= static_cast<MPType>(x_data[idx]));
```

Reply: "OK 👌"
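For readers without the kernel in front of them, that suggestion amounts to something like the sketch below. It is only an illustration: the kernel name `BernoulliCompareKernel` and the pre-generated `uniform_draws` buffer are made up here, and the real Paddle kernel draws its uniforms with curand in float4 batches inside the loop rather than reading a buffer.

```cuda
// Hypothetical sketch of the MPType suggestion for the bernoulli GPU kernel
// (not the code in this PR): compare the float draw against the probability
// in MPType and only cast the 0/1 result back to T when storing.
#include <cstdint>
#include "paddle/phi/common/amp_type_traits.h"

template <typename T>
__global__ void BernoulliCompareKernel(const T* x_data,
                                       const float* uniform_draws,
                                       T* out_data,
                                       int64_t size) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (; idx < size; idx += stride) {
    // The comparison runs in MPType (float for float16/bfloat16), so the
    // half-precision probability is promoted rather than the draw demoted.
    out_data[idx] =
        static_cast<T>(uniform_draws[idx] <= static_cast<MPType>(x_data[idx]));
  }
}
```

The point of the suggestion is that only the final 0/1 result touches the reduced-precision type T; the probability comparison itself keeps full float precision.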