// Patch summary (commentary preamble; everything from "diff --git" onward is the unmodified patch).
//
// oneflow/core/functional/impl/array_functor.cpp, class StackFunctor:
//   * Hunk 1 removes the early return `if (ninput == 1) { return ExpandDims(inputs[0], dim); }`
//     that sat between the stack_dim wrapping and the per-input shape comparison loop, so a
//     single-input call now continues into the code that follows instead of returning there.
//   * Hunk 2 changes the loop that slices `inputs` into chunks of at most kMaxInputCount tensors.
//     A chunk holding exactly one tensor is now appended via functional::ExpandDims(chunk[0], dim);
//     any other chunk is still dispatched through ops_[size - 1] (indexing changed from
//     ops_.at(size - 1) to ops_[size - 1], i.e. operator[] instead of at()).
//     After the loop a lone output is returned directly, otherwise the outputs go to
//     Concat(outputs, stack_dim).
//   * NOTE(review): the ExpandDims call receives the raw `dim`, while Concat receives the wrapped
//     `stack_dim`. Confirm ExpandDims wraps negative dims the same way, or pass stack_dim.
//   * NOTE(review): `std::shared_ptr&` and `OpInterpUtil::Dispatch(` look like they lost their
//     template arguments when this text was extracted - verify against the real patch.
//
// python/oneflow/test/modules/test_stack.py:
//   * Adds test_stack_kMaxInputCount_inputs, decorated with
//     @autotest(auto_backward=True, check_graph=True). It builds 128 + 1 random tensors
//     (ndim=2, dim0=3, dim1=4) and returns torch.stack(stack_list, 0).
//   * The Python local named kMaxInputCount holds 129. Presumably that is one more than the C++
//     kMaxInputCount, so the final chunk contains a single tensor and hits the new ExpandDims
//     branch - TODO confirm, the C++ constant's value is not visible in this patch.
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 13f36692abf..128bccdd56e 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -596,7 +596,6 @@ class StackFunctor { int64_t ndims = inputs[0]->ndim(); int64_t stack_dim = dim; stack_dim = JUST(maybe_wrap_dim(stack_dim, ndims + 1)); - if (ninput == 1) { return ExpandDims(inputs[0], dim); } const std::shared_ptr& first_in_shape = inputs[0]->shape(); for (const auto& input : inputs) { for (int i = 0; i < ndims; ++i) { @@ -615,8 +614,13 @@ class StackFunctor { size_t size = (i + kMaxInputCount) < ninput ? kMaxInputCount : ninput - i; TensorTuple partial_inputs(size); for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } - outputs.emplace_back( - JUST(OpInterpUtil::Dispatch(*ops_.at(size - 1), partial_inputs, attrs))); + if (partial_inputs.size() == 1) { + // Use ExpandDims functor for only one input + outputs.emplace_back(JUST(functional::ExpandDims(partial_inputs[0], dim))); + } else { + outputs.emplace_back( + JUST(OpInterpUtil::Dispatch(*ops_[size - 1], partial_inputs, attrs))); + } } if (outputs.size() == 1) { return outputs.at(0); } return Concat(outputs, stack_dim); diff --git a/python/oneflow/test/modules/test_stack.py b/python/oneflow/test/modules/test_stack.py index cd1f209ce76..f694a28f5c1 100644 --- a/python/oneflow/test/modules/test_stack.py +++ b/python/oneflow/test/modules/test_stack.py @@ -44,6 +44,15 @@ def test_stack_bool_with_random_data(test_case): out = torch.stack((x, y), dim=random(low=1, high=4).to(int)) return out + @autotest(auto_backward=True, check_graph=True) + def test_stack_kMaxInputCount_inputs(test_case): + kMaxInputCount = 128 + 1 + stack_list = [ + random_tensor(ndim=2, dim0=3, dim1=4) for _ in range(kMaxInputCount) + ] + out = torch.stack(stack_list, 0) + return out + if __name__ == "__main__": unittest.main()