From 093056a7a5387cba89204445333980b933f7ef93 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Wed, 26 Jan 2022 14:47:49 +0800 Subject: [PATCH] Fix stack backward (#7363) * fix typo * fix stack backward Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/core/autograd/gradient_funcs/stack.cpp | 12 ++++-------- oneflow/core/functional/impl/array_functor.cpp | 1 + 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/oneflow/core/autograd/gradient_funcs/stack.cpp b/oneflow/core/autograd/gradient_funcs/stack.cpp index 278b58b2842..20e0d11dbbe 100644 --- a/oneflow/core/autograd/gradient_funcs/stack.cpp +++ b/oneflow/core/autograd/gradient_funcs/stack.cpp @@ -65,14 +65,10 @@ Maybe Stack::Apply(const StackCaptureState* ctx, const TensorTuple& out_gr in_grads->resize(ctx->input_num); TensorTuple like(ctx->input_num); for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); } - if (ctx->input_num == 1) { - in_grads->at(0) = out_grads.at(0); - } else { - const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis)); - CHECK_EQ_OR_RETURN(results->size(), ctx->input_num); - for (int i = 0; i < ctx->input_num; ++i) { - if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } - } + const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis)); + CHECK_EQ_OR_RETURN(results->size(), ctx->input_num); + for (int i = 0; i < ctx->input_num; ++i) { + if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } } return Maybe::Ok(); } diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 1ecc49d1747..502d29cf223 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -524,6 +524,7 @@ class StackFunctor { CHECK_OR_RETURN(stack_dim >= 0 && stack_dim <= ndims) << "Index Error: Dimension out of range (expected in range of [" << -ndims - 1 << ", " << ndims << "], but got " << stack_dim; + if (ninput == 1) { return ExpandDims(inputs[0], dim); } const std::shared_ptr& first_in_shape = inputs[0]->shape(); for (const auto& input : inputs) { for (int i = 0; i < ndims; ++i) {