Skip to content

Commit

Permalink
[Operator] Registering torch.Tensor.copy_ (#259)
Browse files Browse the repository at this point in the history
Closes #247 

Also registered `torch.Tensor.rsqrt`, which was encountered in the same
model right after adding `torch.Tensor.copy_`.
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Jul 22, 2024
1 parent 836e3ac commit d0877d5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,7 @@ def rshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Te


@register_function(torch.rsqrt)
@register_method(torch.Tensor.rsqrt)
def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.rsqrt(..., out=...)")
Expand Down Expand Up @@ -1340,6 +1341,15 @@ def torch_clone(x: Tensor, *, memory_format=torch.preserve_format):
return x.copy()


@register_method(torch.Tensor.copy_)
def torch_copy(x: Tensor, src: Tensor, non_blocking: bool = False):
    """Map ``torch.Tensor.copy_`` onto hidet's functional graph semantics.

    Broadcasts ``src`` to the destination shape when the two shapes differ,
    then returns a copy of it (the frontend is functional, so no in-place
    mutation of ``x`` takes place; the returned tensor stands in for the
    updated destination).
    """
    # non_blocking only affects host/device transfer scheduling in torch;
    # it has no effect here, so we warn and proceed synchronously.
    if non_blocking:
        warnings.warn_once("torch.Tensor.copy_ with non_blocking=True is not supported. Treating as non_blocking=False")
    source = src if x.shape == src.shape else ops.broadcast(src, x.shape)
    return torch_clone(source)


@register_function(torch.chunk)
def torch_chunk(x: Tensor, chunks: int, dim: int = 0):
    """Map ``torch.chunk`` to hidet's ``ops.split`` along the given axis."""
    return ops.split(x, axis=dim, parts_or_sections=chunks)
Expand Down
18 changes: 18 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,23 @@ def test_tensor_sigmod(shape):
check_module(FunctionalModule(op=lambda x: x.sigmoid_()), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize(
    'shape,src_shape',
    [
        ([2, 3, 4], [2, 3, 4]),
        ([2, 3, 4], [2, 3, 1]),
        ([2, 3, 4], [2, 1, 1]),
        ([2, 3, 4], [1, 1]),
        ([2, 3, 4], [1]),
        ([5, 3, 4, 1], [3, 1, 1]),
        ([5, 3, 4, 1], [4, 1]),
    ],
)
def test_torch_copy(shape, src_shape):
    # Tensor.copy_ should match torch bit-exactly (it is a pure data move,
    # possibly with broadcasting), hence atol=0 / rtol=0.
    dst = torch.randn(shape)
    src = torch.randn(src_shape)
    module = FunctionalModule(op=lambda x, y: x.copy_(y))
    check_module(module, args=[dst, src], atol=0, rtol=0)


# Allow running this test file directly (python test_torch_interoperability.py).
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit d0877d5

Please sign in to comment.