Skip to content

Commit

Permalink
[Operator] Adding support for torch.Tensor.view_as (#334)
Browse files Browse the repository at this point in the history
Closes #333
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Jul 22, 2024
1 parent 70d1bb2 commit 06ab1a0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def tensor_view(self: Tensor, *args) -> Tensor:
return ops.reshape(self, dst_shape)


@register_method(torch.Tensor.view_as)
def torch_view_as(self: Tensor, other: Tensor) -> Tensor:
return ops.reshape(self, other.shape)


@register_method(torch.Tensor.contiguous)
def tensor_contiguous(self: Tensor) -> Tensor:
# hidet tensor is always contiguous
Expand Down
7 changes: 7 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_expand_as(shape, expanded_shape):
)


@pytest.mark.parametrize('shape, new_shape', [[[2, 3, 4], [6, 4]], [[2, 3, 4], [12, 2]]])
def test_view_as(shape, new_shape):
check_module(
FunctionalModule(op=lambda x: x.view_as(torch.randn(new_shape))), args=[torch.randn(shape)], atol=0, rtol=0
)


@pytest.mark.parametrize('shape', [[2, 3]])
def test_tensor_sigmod(shape):
check_module(FunctionalModule(op=lambda x: x.sigmoid_()), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)
Expand Down

0 comments on commit 06ab1a0

Please sign in to comment.