Skip to content

Commit

Permalink
[BUG] Support concat empty tensors (#475)
Browse files Browse the repository at this point in the history
Refer to [this ](#440).

The user tried to concat an empty tensor to a non-empty one, and
expecting the same behavior as pytorch.
  • Loading branch information
ZichuWu authored Sep 23, 2024
1 parent 9f3e2b9 commit d4cf4da
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/hidet/graph/ops/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def __init__(self, *tensors: Tensor, axis: int):
tensors = list(tensors)
if len(tensors) == 0:
raise ValueError('Concat requires at least one tensor, 0 given.')
tensors = [tensor for tensor in tensors if tensor.shape != (0,)] or [tensors[0]]
axis = normalize_dim(axis, len(tensors[0].shape))
super().__init__(
inputs=tensors,
Expand Down
16 changes: 16 additions & 0 deletions tests/operators/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ def test_concat(shapes, dtype, axis):
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, rtol=0, atol=0)


@pytest.mark.parametrize(
"shapes, dtype, axis",
[
[[[0], [0], [33, 44, 55], [0], [1, 44, 55], [0], [32, 44, 55], [0]], 'float32', 0],
[[[0], [0], [33, 1, 55], [0], [33, 8, 55], [0], [33, 111, 55], [0]], 'float32', 1],
[[[0], [0], [33, 1, 55], [0], [33, 8, 55], [0], [33, 111, 55], [0]], 'float32', -2],
[[[0], [0], [0]], 'float32', 0],
],
)
def test_concat_empty(shapes, dtype, axis):
data_list = [np.random.randn(*shape).astype(dtype) for shape in shapes]
torch_result = torch.cat([torch.asarray(d) for d in data_list], dim=axis)
hidet_result = ops.concat([hi.asarray(data).cuda() for data in data_list], axis).cpu().numpy()
np.testing.assert_allclose(actual=hidet_result, desired=torch_result.cpu().numpy(), atol=0, rtol=0)


@pytest.mark.parametrize("shape, src_type, dst_type", [[[33, 44, 55], "int64", "float32"]])
def test_cast(shape, src_type, dst_type):
check_transform(shape, lambda x: x.astype(dst_type), lambda x: ops.cast(x, dst_type), dtype=src_type)
Expand Down

0 comments on commit d4cf4da

Please sign in to comment.