[Fix] fixed torch.pow (#420)
`torch.pow` accepts the following types of operands:
1. (a: Tensor, b: Tensor)
2. (a: Numeric, b: Tensor)  -> missing case, handled by this PR (see the sketch below)
3. (a: Tensor, b: Numeric)
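
A minimal sketch of the three call forms in plain PyTorch (shapes and values are illustrative only, not taken from this PR):

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])

torch.pow(t, t)    # 1. (Tensor, Tensor)
torch.pow(2.0, t)  # 2. (Numeric, Tensor) -- the case this PR adds to hidet's frontend
torch.pow(t, 2.0)  # 3. (Tensor, Numeric)
```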

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail,.com>
zhumakhan and Zhumakhan authored Aug 15, 2024
1 parent b1681b7 commit 7314155
Showing 2 changed files with 16 additions and 2 deletions.
python/hidet/graph/frontend/torch/register_functions.py (3 additions, 1 deletion)
@@ -960,9 +960,11 @@ def tensor_where(self: Tensor, condition: Tensor, y: Union[Tensor, Number]):
 @register_function(torch.pow)
 @register_method(torch.Tensor.pow)
 @register_method(torch.Tensor.pow_)
-def pow(base: Tensor, exponent: Union[Number, Tensor]):
+def torch_pow(base: Union[Number, Tensor], exponent: Union[Number, Tensor]):
     if isinstance(exponent, (int, float, bool)):
         exponent = full_like(base, exponent)
+    elif isinstance(base, (int, float, bool)):
+        base = full_like(exponent, base)
     return ops.pow(base, exponent)


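The new elif branch broadcasts a numeric base to the exponent's shape before calling the elementwise power op. A rough equivalent of the idea in plain PyTorch (illustration only; the actual code uses hidet's full_like and ops.pow registered above):

```python
import torch

exponent = torch.tensor([1.0, 2.0, 3.0])
base = 2.0  # Numeric base with a Tensor exponent: the previously unhandled case

base_tensor = torch.full_like(exponent, base)  # expand the scalar to the exponent's shape
result = torch.pow(base_tensor, exponent)      # elementwise power

assert torch.allclose(result, torch.pow(base, exponent))
```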
tests/frontends/torch/test_torch_arithmetic.py (13 additions, 1 deletion)
@@ -17,12 +17,24 @@
 
 @pytest.mark.parametrize('a_shape', [[1, 3, 64], [10, 10], [11, 13], [1, 2, 3]])
 @pytest.mark.parametrize('sizes', [[1, 2, 3], [2, 3, 4, 5, 6, 8]])
-def test_tensor_repear(a_shape, sizes):
+def test_tensor_repeat(a_shape, sizes):
     def tensor_repeat(tensor, sizes):
         return tensor.repeat(*sizes)
 
     check_module(FunctionalModule(op=tensor_repeat), args=[torch.randn(a_shape), sizes], atol=0, rtol=0)
 
 
+@pytest.mark.parametrize('a, b', [[[1, 3, 2], [1, 3, 2]], [2.0, [10, 10]], [[11, 13], 2]])
+def test_pow(a, b):
+    if isinstance(a, list) and not isinstance(b, list):
+        args = [torch.randn(a), b]
+    elif isinstance(b, list) and not isinstance(a, list):
+        args = [a, torch.randn(b)]
+    else:
+        args = [torch.randn(a), torch.randn(b)]
+
+    check_module(FunctionalModule(op=torch.pow), args=args, atol=0.0001, rtol=0.0001)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
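
For context, a hedged end-to-end sketch of what the new test exercises, written against torch.compile instead of hidet's check_module/FunctionalModule helpers; it assumes hidet is installed and registers its 'hidet' backend with torch.compile:

```python
import torch

def scalar_base_pow(exponent: torch.Tensor) -> torch.Tensor:
    return torch.pow(2.0, exponent)  # Numeric base, Tensor exponent

# The 'hidet' backend name is an assumption about the installed environment;
# eager PyTorch serves as the reference result.
compiled = torch.compile(scalar_base_pow, backend='hidet')
x = torch.randn(10, 10)
torch.testing.assert_close(compiled(x), scalar_base_pow(x), atol=1e-4, rtol=1e-4)
```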
