From f07de67a814ab5a67726fbfec7e4116f2a9c2ba4 Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Thu, 25 Mar 2021 11:19:03 +0000
Subject: [PATCH] fix concat_grad

---
 paddle/fluid/operators/concat_op_npu.cc | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/concat_op_npu.cc b/paddle/fluid/operators/concat_op_npu.cc
index 9b979dede048f..87bb3397ca267 100644
--- a/paddle/fluid/operators/concat_op_npu.cc
+++ b/paddle/fluid/operators/concat_op_npu.cc
@@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));
 
-    std::vector<int> sizes;
     int offset = 0;
     auto stream = ctx.template device_context<platform::NPUDeviceContext>()
                       .stream();
@@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        sizes.push_back(outs[j]->dims()[axis]);
         std::vector<int> offsets;
         std::vector<int> sizes;
         for (int dim = 0; dim < ins[j]->dims().size(); ++dim) {
@@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
             sizes.push_back(ins[j]->dims()[dim]);
           }
         }
-        auto runner =
-            NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
-                        {{"offsets", offset}, {"size", ins[j]->dims()[axis]}});
+        auto runner = NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
+                                  {{"offsets", offsets}, {"size", sizes}});
         runner.Run(stream);
       }
       if (ins[j]->numel() != 0UL) {