Skip to content

Commit

Permalink
Fix sequence expand op (#11618)
Browse files Browse the repository at this point in the history
* Set zero outside functor
  • Loading branch information
Qingsheng Li authored Jun 27, 2018
1 parent 01fbcb0 commit 8630ba2
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions paddle/fluid/operators/sequence_expand_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
set_zero(context, dx, static_cast<T>(0));

int dout_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
Expand Down Expand Up @@ -187,6 +184,10 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
g_x->mutable_data<T>(context.GetPlace());
g_x->set_lod(x->lod());

auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, g_x, static_cast<T>(0));

auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
// just copy the gradient
Expand Down

0 comments on commit 8630ba2

Please sign in to comment.