fix repeat interleave 0-size tensor bug (#8414)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
BBuf and mergify[bot] authored Jun 14, 2022
1 parent 041f787 commit 9bf8090
Showing 2 changed files with 26 additions and 16 deletions.
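
In user-facing terms, the bug shows up when every entry of repeats is zero: repeat_interleave should return a 0-size tensor along the chosen dim and still support backward. A minimal sketch of that case, with hypothetical input values, mirroring the regression test added below:

import numpy as np
import oneflow as flow

x = flow.tensor(np.random.rand(2, 2, 3))  # illustrative shape, matching the test below
x.requires_grad = True
repeats = flow.tensor([0, 0])             # every repeat count is zero
y = flow.repeat_interleave(x, repeats, 1)
print(y.shape)                            # expected (2, 0, 3): dim 1 collapses to size 0
y.sum().backward()
print(x.grad)                             # expected all-zero gradient, matching PyTorch
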
16 changes: 2 additions & 14 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -2998,20 +2998,8 @@ class RepeatInterLeaveTensorFunctor {
     std::shared_ptr<one::Tensor> cumsum = JUST(Cumsum(repeats, 0, DType::Int32()));
     const int64_t& output_size_value =
         std::accumulate(repeats_value.begin(), repeats_value.end(), 0);
-    std::shared_ptr<one::Tensor> res;
-    if (output_size_value > 0) {
-      res = JUST(IndexSelect(input, dim_,
-                             JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value))));
-    } else {
-      // Deal with 0-size Tensor.
-      DimVector new_input_shape(input_shape->dim_vec().begin(), input_shape->dim_vec().end());
-      new_input_shape[dim_] = 0;
-      std::shared_ptr<one::Tensor> new_input =
-          JUST(Constant(Shape{new_input_shape}, Scalar(0), input->dtype(), JUST(input->device())));
-      res = JUST(IndexSelect(new_input, dim_,
-                             JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value))));
-    }
-    return res;
+    return JUST(
+        IndexSelect(input, dim_, JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value))));
   }
 };
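
After this change the functor unconditionally delegates to IndexSelect, relying on it to handle an empty index tensor (output_size_value == 0). As a rough intuition, here is the analogous behavior at the Python level with torch.index_select; the assumption is that OneFlow's IndexSelect follows the same empty-index semantics:

import torch

x = torch.rand(2, 2, 3)
empty_index = torch.tensor([], dtype=torch.long)

# Selecting with an empty index keeps the other dims and yields size 0 along dim 1,
# the same shape the removed special-case branch built by hand with Constant().
out = torch.index_select(x, 1, empty_index)
print(out.shape)  # torch.Size([2, 0, 3])
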

26 changes: 24 additions & 2 deletions python/oneflow/test/modules/test_repeat_interleave.py
@@ -15,8 +15,10 @@
 """
 import unittest
 
+import numpy as np
 import oneflow as flow
 import oneflow.unittest
+import torch as torch_original
 
 from oneflow.test_utils.automated_test_util import *
 
@@ -39,17 +41,37 @@ def test_flow_int_repeat_interleave_with_dim(test_case):
     @autotest(n=5)
     def test_flow_tensor_repeat_interleave_dim(test_case):
         x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)
-        y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4)
+        y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4)
         z = torch.repeat_interleave(x, y, 1)
         return z
 
     @autotest(n=5)
     def test_flow_tensor_repeat_interleave_dim_with_output_size(test_case):
         x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)
-        y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4)
+        y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4)
         z = torch.repeat_interleave(x, y, 1, output_size=2)
         return z
 
+    def test_flow_tensor_repeat_interleave_0size_tensor(test_case):
+        np_arr = np.array(
+            [
+                [[0.8548, 0.0436, 0.7977], [0.1919, 0.4191, 0.2186]],
+                [[0.4741, 0.8896, 0.6859], [0.5223, 0.7803, 0.1134]],
+            ]
+        )
+        x_torch = torch_original.tensor(np_arr)
+        x_torch.requires_grad = True
+        y_torch = torch_original.tensor([0, 0])
+        z_torch = torch_original.repeat_interleave(x_torch, y_torch, 1)
+        z_torch.sum().backward()
+
+        x_flow = flow.tensor(np_arr)
+        x_flow.requires_grad = True
+        y_flow = flow.tensor([0, 0])
+        z_flow = flow.repeat_interleave(x_flow, y_flow, 1)
+        z_flow.sum().backward()
+        test_case.assertTrue(np.array_equal(x_torch.grad.numpy(), x_flow.grad.numpy()))
+
 
 if __name__ == "__main__":
     unittest.main()
