Skip to content

Commit

Permalink
[NPU] refine lookup_table_v2_grad npu_kernel (#32497)
Browse files Browse the repository at this point in the history
* use ZerosLike instead of NPUMemsetAsync

* fix compile
  • Loading branch information
zhiqiu authored Apr 25, 2021
1 parent 136ef09 commit fb7590d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/operators/lookup_table_v2_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");

auto *output_grad_t =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *table_grad_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
auto *p = table_grad_t->mutable_data<T>(ctx.GetPlace());
table_grad_t->mutable_data<T>(ctx.GetPlace());

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

platform::NPUMemsetAsync(static_cast<void *>(p), 0,
table_grad_t->numel() * sizeof(T), stream);
auto runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);

// NOTE(zhiqiu): It seems in cann 20.1, the first input and output
// can be different tensor, but in cann 20.2+, it does inplace operation.
Expand Down

0 comments on commit fb7590d

Please sign in to comment.