Skip to content

Commit

Permalink
disable EmbeddingDenseGrad temporarily
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Jul 30, 2021
1 parent eb27d8b commit ebfc431
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions paddle/fluid/operators/lookup_table_v2_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

int embedding_dim = table_grad_t->dims()[1];
/* EmbeddingDenseGrad has bug on large shape, temporarily disable it.
int embedding_dim = table_grad_t->dims()[1];
if (embedding_dim % 32 == 0) {
// NOTE(pangyoki): The embedding_dim of Tensor used in
// EmbeddingDenseGrad must be an integer multiple of 32.
Expand All @@ -81,19 +82,21 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{"padding_idx", -1},
{"scale_grad_by_freq", false}});
runner.Run(stream);
} else {
const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);

// NOTE(zhiqiu): It seems in cann 20.1, the first input and output
// can be different tensor, but in cann 20.2+, it does inplace operation.
// Thus, the first input and output should be same tensor.
const auto &runner_scatter =
NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
return;
}
*/

const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);

// NOTE(zhiqiu): It seems in cann 20.1, the first input and output
// can be different tensor, but in cann 20.2+, it does inplace operation.
// Thus, the first input and output should be same tensor.
const auto &runner_scatter =
NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
}
};
} // namespace operators
Expand Down

1 comment on commit ebfc431

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.