Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] fix compute error in fused_dropout_add #52261

Merged
merged 3 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ struct NoMaskBwFunctor {
template <typename T, typename Functor>
__global__ void VectorizedDropoutBackward(const size_t n,
uint64_t seed,
T* src,
T* res,
const T* dst,
T* x,
T* y,
const T* out_grad,
uint64_t increment,
size_t main_offset,
Functor functor) {
Expand All @@ -112,44 +112,38 @@ __global__ void VectorizedDropoutBackward(const size_t n,
#endif

float rands[kCount];
T src_res[kCount * 2];
T res_grad[kCount];
T x_y[kCount * 2];

using Rand = phi::funcs::uniform_distribution<float>;
using Cast = kps::IdentityFunctor<T>;

int deal_size = BLOCK_NUM_X * kCount;
size_t fix = idx * kCount;

for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, false>(&src_res[0], dst, deal_size);
kps::ReadData<T, kCount, 1, false>(&x_y[0], out_grad + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// x_grad
kps::OperatorTernary<T, float, T, Functor>(
&src_res[0], &src_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
// res
kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
&res_grad[0], &src_res[kCount], Cast());
kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
&x_y[0], &x_y[0], &rands[0], functor, kCount);

kps::WriteData<T, kCount, 1, false>(x + fix, &x_y[0], deal_size);
kps::WriteData<T, kCount, 1, false>(y + fix, &x_y[kCount], deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}

int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
kps::ReadData<T, kCount, 1, true>(&x_y[0], out_grad + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// x_grad
kps::OperatorTernary<T, float, T, Functor>(
&src_res[0], &src_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
&x_y[0], &x_y[0], &rands[0], functor, kCount);

// res
kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
&res_grad[0], &src_res[kCount], Cast());
kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
kps::WriteData<T, kCount, 1, true>(x + fix, &x_y[0], remainder);
kps::WriteData<T, kCount, 1, true>(y + fix, &x_y[kCount], remainder);
__syncthreads();
}
}
Expand Down Expand Up @@ -201,7 +195,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
size_t block_size = random_prop[1];
size_t offset = random_prop[2];
size_t main_offset = random_prop[3];

auto functor = upscale_in_train
? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
: NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);
Expand Down
15 changes: 7 additions & 8 deletions python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
)
class TestFusedDropoutAdd(unittest.TestCase):
def setUp(self):
self.shape = (2, 10, 10, 2)
self.dtype = 'float64'
self.dropout_rate = 0.9
self.shape = [2, 1024, 2, 1]
self.dtype = 'float16'
self.dropout_rate = 0.5
self.training = True
self.mode = "upscale_in_train"
self.seed = 1027
Expand Down Expand Up @@ -66,9 +66,8 @@ def get_forward_backward(self, dropout_add, seed):
mode=self.mode,
)
fw.append(out)

loss = paddle.mean(out)
loss.backward()
out_g = paddle.randn(self.shape, self.dtype)
paddle.autograd.backward([out], [out_g], True)
for i in range(count):
bw.append(data[i].grad)
return fw, bw
Expand All @@ -95,7 +94,7 @@ def create_test_class(parent, dtype, mode, training, p, seed):
)
class TestFusedDropoutAddCase(parent):
def setUp(self):
self.shape = (2, 10, 10, 2)
self.shape = (2, 1024, 1, 1)
self.dtype = dtype
self.dropout_rate = p
self.training = training
Expand Down Expand Up @@ -168,7 +167,7 @@ def test_fused_dropout_add_layer(self):
y = paddle.randn(self.shape, self.dtype)
fused_d_a = FusedDropoutAdd(p=0.5)
d = paddle.nn.Dropout(p=0.5)
print(d)
print(d.extra_repr())
paddle.seed(2048)
fused_out = fused_d_a(x, y)
paddle.seed(2048)
Expand Down