From 9bf8090edd1d70d710d780f4dd689ebda096bd6d Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Wed, 15 Jun 2022 02:02:42 +0800 Subject: [PATCH] fix repeat interleave 0-size tensor bug (#8414) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .../core/functional/impl/array_functor.cpp | 16 ++---------- .../test/modules/test_repeat_interleave.py | 26 +++++++++++++++++-- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index baf634f1679..5bfa9b45417 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -2998,20 +2998,8 @@ class RepeatInterLeaveTensorFunctor { std::shared_ptr cumsum = JUST(Cumsum(repeats, 0, DType::Int32())); const int64_t& output_size_value = std::accumulate(repeats_value.begin(), repeats_value.end(), 0); - std::shared_ptr res; - if (output_size_value > 0) { - res = JUST(IndexSelect(input, dim_, - JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value)))); - } else { - // Deal with 0-size Tensor. - DimVector new_input_shape(input_shape->dim_vec().begin(), input_shape->dim_vec().end()); - new_input_shape[dim_] = 0; - std::shared_ptr new_input = - JUST(Constant(Shape{new_input_shape}, Scalar(0), input->dtype(), JUST(input->device()))); - res = JUST(IndexSelect(new_input, dim_, - JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value)))); - } - return res; + return JUST( + IndexSelect(input, dim_, JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value)))); } }; diff --git a/python/oneflow/test/modules/test_repeat_interleave.py b/python/oneflow/test/modules/test_repeat_interleave.py index 95faea06ac5..5a636f0e66c 100644 --- a/python/oneflow/test/modules/test_repeat_interleave.py +++ b/python/oneflow/test/modules/test_repeat_interleave.py @@ -15,8 +15,10 @@ """ import unittest +import numpy as np import oneflow as flow import oneflow.unittest +import torch as torch_original from oneflow.test_utils.automated_test_util import * @@ -39,17 +41,37 @@ def test_flow_int_repeat_interleave_with_dim(test_case): @autotest(n=5) def test_flow_tensor_repeat_interleave_dim(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) - y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4) + y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4) z = torch.repeat_interleave(x, y, 1) return z @autotest(n=5) def test_flow_tensor_repeat_interleave_dim_with_output_size(test_case): x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3) - y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4) + y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4) z = torch.repeat_interleave(x, y, 1, output_size=2) return z + def test_flow_tensor_repeat_interleave_0size_tensor(test_case): + np_arr = np.array( + [ + [[0.8548, 0.0436, 0.7977], [0.1919, 0.4191, 0.2186]], + [[0.4741, 0.8896, 0.6859], [0.5223, 0.7803, 0.1134]], + ] + ) + x_torch = torch_original.tensor(np_arr) + x_torch.requires_grad = True + y_torch = torch_original.tensor([0, 0]) + z_torch = torch_original.repeat_interleave(x_torch, y_torch, 1) + z_torch.sum().backward() + + x_flow = flow.tensor(np_arr) + x_flow.requires_grad = True + y_flow = flow.tensor([0, 0]) + z_flow = flow.repeat_interleave(x_flow, y_flow, 1) + z_flow.sum().backward() + test_case.assertTrue(np.array_equal(x_torch.grad.numpy(), x_flow.grad.numpy())) + if __name__ == "__main__": unittest.main()