Skip to content

Commit

Permalink
fix concat_grad
Browse files · Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Mar 25, 2021
1 parent 4b36ee2 commit f07de67
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions paddle/fluid/operators/concat_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));

std::vector<int> sizes;
int offset = 0;
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
Expand All @@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
sizes.push_back(outs[j]->dims()[axis]);
std::vector<int> offsets;
std::vector<int> sizes;
for (int dim = 0; dim < ins[j]->dims().size(); ++dim) {
Expand All @@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
sizes.push_back(ins[j]->dims()[dim]);
}
}
auto runner =
NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
{{"offsets", offset}, {"size", ins[j]->dims()[axis]}});
auto runner = NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
{{"offsets", offsets}, {"size", sizes}});
runner.Run(stream);
}
if (ins[j]->numel() != 0UL) {
Expand Down

0 comments on commit f07de67

Please sign in to comment.